gaoqiong / composable_kernel / Commits

Commit bf80ceee (unverified)
Authored Feb 12, 2023 by guangzlu; committed by GitHub on Feb 12, 2023
Parents: 9dc05eaa, a2887630

    Merge branch 'attn-fwd-dropout-verify' into attn-bwd-dropout
Showing 5 changed files with 1256 additions and 1 deletion:
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+5, -0)
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+20, -0)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp (+5, -0)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle (+1058, -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+168, -1)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (mode 100755 → 100644)
@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -42,6 +43,7 @@ using B1DataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
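The new ZDataType is a 16-bit unsigned integer per attention-score element; elsewhere in this commit it carries the dropout randomness/mask out of the kernel so the dropped softmax probabilities can be reproduced on the host. The exact encoding lives in the blockwise dropout implementation, which this page does not show, so the sketch below only illustrates the general idea: one 16-bit draw per element is enough to make a keep/drop decision and rescale, and saving those draws makes the result replayable. The threshold convention and p_drop value are assumptions, not taken from the commit.

// Illustration only -- not code from this commit. Shows how one 16-bit value per
// element (the role ZDataType = U16 plays here) can drive a dropout decision that
// the host can replay exactly. The threshold convention is an assumption.
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main()
{
    const float p_drop       = 0.2f;
    const uint16_t threshold = static_cast<uint16_t>(p_drop * 65535.0f); // drop if draw < threshold
    const float rescale      = 1.0f / (1.0f - p_drop);

    std::mt19937 rng(0);
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    std::vector<float> prob{0.1f, 0.5f, 0.9f, 0.3f}; // softmax probabilities for one row
    std::vector<uint16_t> z(prob.size());            // saved per-element draws, like the Z tensor

    for(std::size_t i = 0; i < prob.size(); ++i)
    {
        z[i]    = static_cast<uint16_t>(dist(rng));
        prob[i] = (z[i] < threshold) ? 0.0f : prob[i] * rescale; // P_dropped
    }

    for(float p : prob)
        std::cout << p << ' ';
    std::cout << '\n';
    return 0;
}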
@@ -69,6 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -78,6 +81,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -159,4 +163,5 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_b0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
+    std::vector<void*> p_z;
     std::vector<void*> p_lse;
     std::vector<std::vector<int>> g0_g1_m_n_k_o;
@@ -55,6 +56,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<B0DataType>> b0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;
+    std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;

     using DeviceMemPtr = std::unique_ptr<DeviceMem>;
@@ -62,6 +64,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> b0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
+    std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;

     std::size_t flop = 0, num_byte = 0;
@@ -102,6 +105,12 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+        std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> z_gs_ms_ns_strides =
+            output_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
         std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
         std::vector<ck::index_t> lse_gs_ms_strides =
             std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
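The new Z strides mirror the existing C-output convention: the same logical [G0, G1, M, N] tensor is stored interleaved as [G0, M, G1, N] when output_permute is set, or packed as [G0, G1, M, N] otherwise. Below is a minimal sketch of what those stride sets mean when a (g0, g1, m, n) coordinate is turned into a linear offset; the concrete sizes are made-up values, not taken from the example.

// Illustration only. Maps a (g0, g1, m, n) index to a linear offset under the two
// stride sets added above for the Z tensor. G0/G1/M/N values are arbitrary.
#include <array>
#include <cstdint>
#include <iostream>

int main()
{
    const int64_t G0 = 2, G1 = 3, M = 4, N = 5;

    const std::array<int64_t, 4> permuted{M * G1 * N, N, G1 * N, 1}; // Z layout [G0, M, G1, N]
    const std::array<int64_t, 4> packed{G1 * M * N, M * N, N, 1};    // Z layout [G0, G1, M, N]

    auto offset = [](const std::array<int64_t, 4>& s, int64_t g0, int64_t g1, int64_t m, int64_t n) {
        return g0 * s[0] + g1 * s[1] + m * s[2] + n * s[3];
    };

    // The same logical element generally lands at a different linear position,
    // but both layouts span exactly G0 * G1 * M * N distinct offsets.
    std::cout << "permuted offset of (1,2,1,4): " << offset(permuted, 1, 2, 1, 4) << '\n';
    std::cout << "packed   offset of (1,2,1,4): " << offset(packed, 1, 2, 1, 4) << '\n';
    return 0;
}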
@@ -114,6 +123,8 @@ int run(int argc, char* argv[])
             b1_gs_os_ns_strides,
             c_gs_ms_os_lengths,
             c_gs_ms_os_strides,
+            z_gs_ms_ns_lengths,
+            z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
             {}, // acc0_biases_gs_ms_ns_lengths
@@ -126,6 +137,7 @@ int run(int argc, char* argv[])
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

         int Batch = G0 * G1;
@@ -140,10 +152,13 @@ int run(int argc, char* argv[])
                       << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                       << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                       << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
                       << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
                       << std::endl;
         }

+        z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
+
         switch(init_method)
         {
         case 0: break;
@@ -172,6 +187,7 @@ int run(int argc, char* argv[])
         b0_tensors.push_back(b0_gs_ns_ks);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
+        z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms_device_result);

         a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
@@ -182,6 +198,8 @@ int run(int argc, char* argv[])
             sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+        z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
         lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
@@ -193,6 +211,7 @@ int run(int argc, char* argv[])
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
         p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
         p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
         p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
     }
@@ -209,6 +228,7 @@ int run(int argc, char* argv[])
         p_b0,
         p_b1,
         p_c,
+        p_z,
         p_lse,
         {}, // p_acc0_biases
         {}, // p_acc1_biases
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
@@ -79,6 +79,7 @@ template <index_t NumDimG,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+        std::vector<index_t> z_gs_ms_ns_lengths;
+        std::vector<index_t> z_gs_ms_ns_strides;
         std::vector<index_t> lse_gs_ms_lengths;
         std::vector<index_t> lse_gs_ms_strides;
@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<const void*> p_b0_vec,
         std::vector<const void*> p_b1_vec,
         std::vector<void*> p_c_vec,
+        std::vector<void*> p_z_vec,
         std::vector<void*> p_lse_vec,
         std::vector<std::vector<const void*>> p_acc0_biases_vec,
         std::vector<std::vector<const void*>> p_acc1_biases_vec,
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle (new file, mode 0 → 100644)
This diff (+1058 lines) is collapsed in this view.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
@@ -35,6 +35,7 @@ template <typename FloatAB,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
+          typename ZGridDesc_M_N,
           typename LSEGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

+    static constexpr auto WaveSize = 64;
+
     // K1 should be Number<...>
     // Gemm0
     static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

+    // C desc for source in blockwise copy
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
+    {
+        const auto M = z_grid_desc_m_n.GetLength(I0);
+        const auto N = z_grid_desc_m_n.GetLength(I1);
+
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N3   = mfma.num_groups_per_blk;
+        constexpr auto N4   = mfma.num_input_blks;
+        constexpr auto N5   = mfma.group_size;
+        return transform_tensor_descriptor(
+            z_grid_desc_m_n,
+            make_tuple(
+                make_unmerge_transform(make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
+                make_unmerge_transform(make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) ////=> for z use
+    {
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N3   = mfma.num_groups_per_blk;
+        constexpr auto N4   = mfma.num_input_blks;
+        constexpr auto N5   = mfma.group_size;
+        return transform_tensor_descriptor(
+            make_naive_tensor_descriptor_packed(make_tuple(M, N)),
+            make_tuple(
+                make_unmerge_transform(make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
+                make_unmerge_transform(make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
+    }
+
+    __device__ static auto GetGemm0WaveIdx()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+
+        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
+    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
+    {
+        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
+            make_tuple(Sequence<0, 1>{}),
+            make_tuple(Sequence<0>{}));
+
+        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
     MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
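GetGemm0WaveIdx and GetGemm0WaveMNIdx recover, from the flat thread id, which wave a thread belongs to and where it sits inside that wave; those indices later seed each thread's starting position for the Z copy. Stripped of the adaptor machinery, the arithmetic appears to be ordinary division and modulo with the last merged dimension varying fastest. A standalone sketch with assumed tile parameters follows (Gemm0MWaves, Gemm0NWaves, and MPerXdl are illustrative values, not fixed by this diff):

// Illustration only. Mirrors what the two merge-transform adaptors added above appear
// to compute, using plain integer arithmetic. Gemm0MWaves/Gemm0NWaves/MPerXdl are
// assumed example values.
#include <iostream>

int main()
{
    constexpr int WaveSize    = 64; // matches the WaveSize constant added in this commit
    constexpr int Gemm0MWaves = 2;  // assumption
    constexpr int Gemm0NWaves = 2;  // assumption
    constexpr int MPerXdl     = 32; // assumption

    for(int thread_id : {0, 63, 64, 130, 255})
    {
        // GetGemm0WaveIdx: thread_id -> (m_wave, n_wave, lane_in_wave)
        const int lane   = thread_id % WaveSize;
        const int wave   = thread_id / WaveSize;
        const int n_wave = wave % Gemm0NWaves;
        const int m_wave = wave / Gemm0NWaves;

        // GetGemm0WaveMNIdx: lane_in_wave -> (input_blk, m_in_xdl)
        const int m_in_xdl  = lane % MPerXdl;
        const int input_blk = lane / MPerXdl;

        std::cout << "thread " << thread_id << " -> wave (" << m_wave << ", " << n_wave
                  << "), lane " << lane << " -> (" << input_blk << ", " << m_in_xdl << ")\n";
    }
    return 0;
}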
@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

+    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
+
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const FloatAB* __restrict__ p_b_grid,
         const FloatAB* __restrict__ p_b1_grid,
         FloatC* __restrict__ p_c_grid,
+        unsigned short* __restrict__ p_z_grid,
         FloatLSE* __restrict__ p_lse_grid,
         void* __restrict__ p_shared,
         const AElementwiseOperation& a_element_op,
@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
         const LSEGridDesc_M& lse_grid_desc_m,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;

+        ///////////////////=>z for dropout
+        //
+        // z vgpr copy to global
+        //
+
+        // z matrix threadwise desc
+        constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
+                                                           I1,   // NBlockID
+                                                           m0,   // MRepeat
+                                                           n0,   // NRepeat
+                                                           m1,   // MWaveId
+                                                           n1,   // NWaveId
+                                                           m2,   // MPerXdl
+                                                           n2,   // NGroupNum
+                                                           n3,   // NInputNum
+                                                           n4)); // registerNum
+
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     unsigned short,
+                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
+                     true>
+            z_tenor_buffer;
+        z_tenor_buffer.Clear();
+
+        // z matrix global desc
+        /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
+
+        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
+
+        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+
+        const auto wave_id     = GetGemm0WaveIdx();
+        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
+
+        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
+            ushort,
+            ushort,
+            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<I1, // MBlockId
+                     I1, // NBlockID
+                     m0, // MRepeat
+                     n0, // NRepeat
+                     m1, // MWaveId
+                     n1, // NWaveId
+                     m2, // MPerXdl
+                     n2, // NGroupNum
+                     n3, // NInputNum
+                     n4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+            9,  // DstVectorDim
+            n4, // DstScalarPerVector
+            InMemoryDataOperationEnum::Set,
+            1, // DstScalarStrideInVector
+            true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                  make_multi_index(block_work_idx[I0], // MBlockId
+                                   0,                  // NBlockId
+                                   0,                  // mrepeat
+                                   0,                  // nrepeat
+                                   wave_id[I0],        // MWaveId
+                                   wave_id[I1],        // NWaveId
+                                   wave_m_n_id[I1],    // MPerXdl
+                                   0,                  // group
+                                   wave_m_n_id[I0],    // NInputIndex
+                                   0),
+                  tensor_operation::element_wise::PassThrough{}};
+        ///////////////////=>z for dropout
+
         do
         {
             auto n_block_data_idx_on_grid =
...
@@ -876,8 +1017,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
if
constexpr
(
IsDropout
)
// dropout
{
{
blockwise_dropout
.
ApplyDropout
(
acc_thread_buf
,
ph
);
// save z to global
if
(
p_z_grid
)
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
}
}
else
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
true
>(
acc_thread_buf
,
ph
);
}
}
//if constexpr(IsDropout) // dropout
//{
// blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
//}
// TODO: may convert to log domain
// TODO: may convert to log domain
running_max_new
=
mathext
::
max
(
max
,
running_max
);
running_max_new
=
mathext
::
max
(
max
,
running_max
);
...
...
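With the mask written out (only when a non-null p_z_grid is passed; otherwise the kernel falls back to in-place dropout, as the if(p_z_grid)/else branch above shows), a reference implementation can replay exactly the same keep/drop pattern instead of re-deriving the device RNG stream, which is presumably what the merged 'attn-fwd-dropout-verify' work relies on. The sketch below illustrates that replay idea on the host; it is not the repository's reference code, and the threshold convention is the same assumption used in the earlier sketch.

// Illustration only. Replays a saved per-element 16-bit dropout mask on a host-side
// reference probability matrix so that device and reference results drop the same
// elements. The threshold rule is an assumption, not taken from this commit.
#include <cstdint>
#include <iostream>
#include <vector>

void replay_dropout(std::vector<float>& p,          // reference softmax probabilities
                    const std::vector<uint16_t>& z, // mask saved by the kernel
                    float p_drop)
{
    const uint16_t threshold = static_cast<uint16_t>(p_drop * 65535.0f);
    const float rescale      = 1.0f / (1.0f - p_drop);

    for(std::size_t i = 0; i < p.size(); ++i)
        p[i] = (z[i] < threshold) ? 0.0f : p[i] * rescale;
}

int main()
{
    std::vector<float> p{0.25f, 0.25f, 0.25f, 0.25f};
    std::vector<uint16_t> z{100, 60000, 20000, 5000}; // pretend these came from the Z tensor

    replay_dropout(p, z, /*p_drop=*/0.2f);

    for(float v : p)
        std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}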