gaoqiong / composable_kernel · Commits

Commit cb2d4dbb, authored Feb 10, 2023 by ltqin

Merge branch 'attn-bwd-dropout' into attn-fwd-train-dropout

Parents: 989e3d10, 0e7aeef5
Changes: 29

Showing 20 changed files with 4692 additions and 102 deletions (+4692 −102)
- example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  +6 −4
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  +810 −0
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp  +3 −3
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp  +3 −3
- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp  +3 −3
- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp  +3 −3
- example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc  +0 −0
- example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc  +0 −0
- include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp  +41 −2
- include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  +137 −2
- include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp  +18 −0
- include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp  +1 −1
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp  +1 −1
- include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp  +1256 −0
- include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp  +25 −25
- include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp  +52 −52
- include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +2 −0
- include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  +2316 −0
- include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp  +1 −1
- include/ck/utility/dynamic_buffer.hpp  +14 −2
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (+6 −4)

```diff
@@ -3,12 +3,14 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
```
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (new file, mode 100644; +810 −0 — diff collapsed)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp  (+3 −3)

```diff
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                 B1ElementOp,
                                                                 CElementOp>;
 
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
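Beyond the include and instance renames, the example's computation is unchanged: it still exercises the fused Gemm + Softmax + Gemm pipeline named in its header comment. For orientation, a naive host-side sketch of that computation (a hypothetical reference helper, not the library's API; single batch, row-major, no masking, bias, or dropout):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference for C = Softmax(scale * (A * B0)) * B1, with
// A: M x K, B0: K x N, B1: N x O, C: M x O (all row-major).
std::vector<float> attention_forward_ref(const std::vector<float>& A,
                                         const std::vector<float>& B0,
                                         const std::vector<float>& B1,
                                         int M, int K, int N, int O, float scale)
{
    std::vector<float> S(M * N), C(M * O, 0.f);
    for(int m = 0; m < M; ++m) // Gemm0: S = scale * (A * B0)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k) acc += A[m * K + k] * B0[k * N + n];
            S[m * N + n] = scale * acc;
        }
    for(int m = 0; m < M; ++m) // row-wise numerically-stable softmax
    {
        float mx = S[m * N], sum = 0.f;
        for(int n = 1; n < N; ++n) mx = std::max(mx, S[m * N + n]);
        for(int n = 0; n < N; ++n)
        {
            S[m * N + n] = std::exp(S[m * N + n] - mx);
            sum += S[m * N + n];
        }
        for(int n = 0; n < N; ++n) S[m * N + n] /= sum;
    }
    for(int m = 0; m < M; ++m) // Gemm1: C = Softmax(S) * B1
        for(int o = 0; o < O; ++o)
            for(int n = 0; n < N; ++n)
                C[m * O + o] += S[m * N + n] * B1[n * O + o];
    return C;
}
```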
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp  (+3 −3)

```diff
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                 B1ElementOp,
                                                                 CElementOp>;
 
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp → example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp  (+3 −3)

```diff
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                 B1ElementOp,
                                                                 CElementOp>;
 
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp  (+3 −3)

```diff
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                 B1ElementOp,
                                                                 CElementOp>;
 
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc → example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc  (file moved, no content changes)
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc → example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc  (file moved, no content changes)
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp  (+41 −2)

```diff
@@ -16,12 +16,15 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
 
-    template <typename CThreadBuffer>
+    template <typename CThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
     {
         auto execute_dropout = [&](bool keep, DataType val) {
-            return keep ? val * p_dropout_rescale : float(0);
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
         };
 
         constexpr int tmp_size = MRepeat * KRepeat;
@@ -47,6 +50,42 @@ struct BlockwiseDropout
             });
         });
     }
 
+    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void
+    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+
+        int philox_calls = tmp_size / 8;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            ph.get_random_8x16((tmp + i * 8));
+        }
+
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                z_thread_buf(offset) = tmp[tmp_index];
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
+
     ushort p_dropout_16bits;
     DataType p_dropout_rescale;
 };
```
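Both the new `using_sign_bit` path and the Z-buffer overload apply the same per-element rule: draw a uniform 16-bit value from Philox, compare it against the 16-bit keep threshold, then either rescale-or-zero the element, or encode the keep/drop decision in the element's sign bit (presumably so a later pass can recover the mask). A scalar model of that rule, as a hypothetical standalone sketch mirroring `execute_dropout` above:

```cpp
// r: uniform 16-bit draw from Philox; threshold_16bits plays the role of
// p_dropout_16bits; rescale plays the role of p_dropout_rescale (1 / keep-prob).
float dropout_element(unsigned short r, unsigned short threshold_16bits,
                      float rescale, float val, bool using_sign_bit)
{
    const bool keep = r < threshold_16bits; // same comparison as tmp[i] < p_dropout_16bits
    if(using_sign_bit)
        return keep ? val : -val; // mask survives in the sign bit instead of zeroing
    return keep ? val * rescale : 0.0f; // classic inverted dropout
}
```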
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  (+137 −2)

```diff
@@ -50,7 +50,8 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          bool TransposeC = false>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
 
-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
 
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -185,6 +186,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                       "wrong!");
     }
 
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -211,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }
 
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
@@ -303,6 +334,58 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
     static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
 
+    __host__ __device__ static constexpr auto MakeCThreadTileIterator()
+    {
+        constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
+        return SpaceFillingCurve<decltype(c_thread_lengths),
+                                 typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                                 typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                                 false>{}; // SnakeCurved
+    }
+
+    __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
+    {
+        if constexpr(TransposeC)
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+        else
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+    }
+
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
@@ -905,6 +988,58 @@ struct BlockwiseGemmXdlops_v2
     static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
     static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
 
+    __host__ __device__ static constexpr auto MakeCThreadTileIterator()
+    {
+        constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
+        return SpaceFillingCurve<decltype(c_thread_lengths),
+                                 typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                                 typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                                 false>{}; // SnakeCurved
+    }
+
+    __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
+    {
+        if constexpr(TransposeC)
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+        else
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+    }
+
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
```
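The `TransposeC` additions rest on the basic transpose identity: feeding the MFMA the swapped operands yields the transposed product,

$$C^\top = (A B)^\top = B^\top A^\top,$$

which is what the `C_xdl' = B_xdl' * A_xdl'` comments mean. The same descriptor machinery serves both orientations; only the roles of the M- and N-subdimensions in the 8-D thread layout are exchanged (M0/M1/M2 against N0..N4), which is exactly the difference between the `_M2_M3_M4_N2` and `_N2_N3_N4` descriptor variants above.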
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp  (+18 −0)

```diff
@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
             });
         });
     }
 
+    template <typename CThreadBuffer, typename LSEBuffer>
+    __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
+                                                 const LSEBuffer& lse_thread_buf)
+    {
+        // calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
+        // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+        //    = exp(Si) / exp(log(sum(exp() + ...)))
+        //    = exp(Si - log(sum(exp() + ...)))
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                                            ? 0
+                                            : math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
+            });
+        });
+    }
+
     BufferType max_value_buf;
     BufferType sum_value_buf;
 };
```
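The identity behind `RunWithPreCalcStats`, written out: with the per-row statistic $\mathrm{LSE} = \log \sum_j e^{s_j}$ saved by the forward pass,

$$p_i = \frac{e^{s_i}}{\sum_j e^{s_j}} = \exp\!\left(s_i - \log \sum_j e^{s_j}\right) = \exp(s_i - \mathrm{LSE}),$$

so the softmax probabilities can be recomputed from the raw scores with one subtraction and one exponential per element, without re-running the row-wise max/sum reductions.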
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp  (+1 −1)

```diff
@@ -84,7 +84,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 {
     static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
     static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
```
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp  (+1 −1)

```diff
@@ -88,7 +88,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
 {
     struct ProblemDesc
     {
```
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp  (new file, mode 100644; +1256 −0 — diff collapsed)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp  (+25 −25)

```diff
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -47,7 +47,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_batched_multiheadattention_forward_xdl_cshuffle(
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const FloatAB* __restrict__ p_b1_grid,
@@ -205,25 +205,25 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceBatchedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceBatchedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       NumDimO,
                                                       ADataType,
                                                       BDataType,
                                                       B1DataType,
                                                       CDataType,
                                                       LSEDataType,
                                                       Acc0BiasDataType,
                                                       Acc1BiasDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       AccElementwiseOperation,
                                                       B1ElementwiseOperation,
                                                       CElementwiseOperation,
                                                       MaskingSpec>
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -244,7 +244,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
 
-    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -382,7 +382,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
 
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         float ave_time = 0;
 
         auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
-            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+            const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
@@ -958,7 +958,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto str = std::stringstream();
 
         // clang-format off
-        str << "DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
```
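The `launch_kernel` lambda receives `has_main_k_block_loop_` and `is_dropout_` as compile-time constants, so each combination of runtime flags selects its own fully specialized kernel instantiation. A minimal standalone sketch of that dispatch pattern (hypothetical names, not the library's API):

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in for the real kernel template: the bools are template parameters,
// so dead branches compile away inside each instantiation.
template <bool HasMainKBlockLoop, bool IsDropout>
void run_kernel()
{
    std::printf("kernel<main_loop=%d, dropout=%d>\n", HasMainKBlockLoop, IsDropout);
}

// Convert two runtime bools into std::integral_constant arguments, the same
// trick launch_kernel uses: the generic lambda instantiates one kernel per case.
void dispatch(bool has_main_loop, bool is_dropout)
{
    auto launch = [&](auto main_c, auto drop_c) {
        run_kernel<decltype(main_c)::value, decltype(drop_c)::value>();
    };
    if(has_main_loop)
        is_dropout ? launch(std::true_type{}, std::true_type{})
                   : launch(std::true_type{}, std::false_type{});
    else
        is_dropout ? launch(std::false_type{}, std::true_type{})
                   : launch(std::false_type{}, std::false_type{});
}

int main()
{
    dispatch(/*has_main_loop=*/true, /*is_dropout=*/false); // prints kernel<1, 0>
    return 0;
}
```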
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp  (+52 −52)

```diff
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -37,7 +37,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_grouped_multiheadattention_forward_xdl_cshuffle(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -197,25 +197,25 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       NumDimO,
                                                       ADataType,
                                                       BDataType,
                                                       B1DataType,
                                                       CDataType,
                                                       LSEDataType,
                                                       Acc0BiasDataType,
                                                       Acc1BiasDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       AccElementwiseOperation,
                                                       B1ElementwiseOperation,
                                                       CElementwiseOperation,
                                                       MaskingSpec>
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -236,25 +236,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
 
-    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
 
-    using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+    using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                           NumDimM,
                                                                           NumDimN,
                                                                           NumDimK,
                                                                           NumDimO,
                                                                           ADataType,
                                                                           BDataType,
                                                                           B1DataType,
                                                                           CDataType,
                                                                           LSEDataType,
                                                                           Acc0BiasDataType,
                                                                           Acc1BiasDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           AccElementwiseOperation,
                                                                           B1ElementwiseOperation,
                                                                           CElementwiseOperation,
                                                                           MaskingSpec>::ProblemDesc;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -392,7 +392,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
 
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -705,16 +705,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel =
-                kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
                                                                  GemmAccDataType,
                                                                  GroupKernelArg,
                                                                  AElementwiseOperation,
                                                                  BElementwiseOperation,
                                                                  AccElementwiseOperation,
                                                                  B1ElementwiseOperation,
                                                                  CElementwiseOperation,
                                                                  has_main_k_block_loop_,
                                                                  is_dropout_>;
 
             return launch_and_time_kernel(
                 stream_config,
@@ -969,7 +969,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto str = std::stringstream();
 
         // clang-format off
-        str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
```
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (+2 −0)

```diff
@@ -95,6 +95,8 @@ struct Scale
         y = scale_ * x;
     };
 
+    __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+
     float scale_;
 };
```
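The new `Append` lets callers fold an extra factor into an existing `Scale` functor instead of running a second elementwise pass. A small standalone sketch (a mock of the struct above with hypothetical values, not the library's full definition):

```cpp
// Minimal host-only mock of ck's Scale functor for illustration.
struct Scale
{
    void operator()(float& y, const float& x) const { y = scale_ * x; }
    void Append(float scale) { scale_ = scale_ * scale; } // compose multiplicatively
    float scale_;
};

int main()
{
    Scale op{0.125f}; // e.g. a softmax scaling such as 1/sqrt(d_k)
    op.Append(2.0f);  // fold in a second factor, e.g. a dropout rescale 1/(1-p)
    float y;
    op(y, 1.0f);      // y == 0.25f: one fused multiply instead of two passes
    return 0;
}
```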
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (new file, mode 100644; +2316 −0 — diff collapsed)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp → include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp  (+1 −1)

```diff
@@ -83,7 +83,7 @@ template <typename FloatAB,
           bool PadN,
           bool MaskOutUpperTriangle,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
+struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
```
include/ck/utility/dynamic_buffer.hpp  (+14 −2)

```diff
@@ -143,6 +143,16 @@ struct DynamicBuffer
         }
     }
 
+    __host__ __device__ void Clear()
+    {
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
+                      "wrong! only local data share is supported");
+
+        for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
+        {
+            Set(i, true, T{0});
+        }
+    }
+
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                          typename scalar_type<remove_cvref_t<T>>::type>::value,
@@ -302,7 +312,9 @@ struct DynamicBuffer
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                       "wrong! X should contain multiple T");
 
-        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
+                          GetAddressSpace() == AddressSpaceEnum::Lds,
+                      "only support global mem or local data share");
 
 #if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
         bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
         bool constexpr use_amd_buffer_addressing = false;
 #endif
 
-        if constexpr(use_amd_buffer_addressing)
+        if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
         {
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
```
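`Clear` follows the usual cooperative-LDS pattern: each thread starts at its lane id and strides by the block size, so the block as a whole touches every element exactly once with no overlap. A plain-C++ sketch of that loop (hypothetical stand-ins: `thread_id` and `block_size` play the roles of `get_thread_local_1d_id()` and `get_block_size()`):

```cpp
// Each of block_size threads zeroes the elements i = thread_id, thread_id +
// block_size, thread_id + 2*block_size, ... so the union of all threads'
// strided subsets covers [0, element_space_size) exactly once.
void clear_cooperative(float* lds_buf, int element_space_size,
                       int thread_id, int block_size)
{
    for(int i = thread_id; i < element_space_size; i += block_size)
        lds_buf[i] = 0.0f;
}
```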