gaoqiong / composable_kernel / Commits / 1aed5cb0

Commit 1aed5cb0, authored Feb 15, 2023 by guangzlu

    added group fwd mha dropout verify

Parent: cdc6f6ba

Showing 4 changed files, with 82 additions and 65 deletions:
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+5, -0)
  example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+63, -29)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+8, -6)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+6, -30)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+5, -0)
```diff
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
```
```diff
@@ -161,6 +162,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                   B1ElementOp,
                                                                   CElementOp>;
 
+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
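The exact interface of `ReferenceDropout` lives in `reference_dropout.hpp`; what this commit verifies is mask-and-scale dropout driven by the pre-generated Z tensor. A minimal host-side sketch, assuming Z holds 16-bit uniform random values and the keep test is `z <= p_dropout_in_16bits` (the function and names below are illustrative, not the CK API):

```cpp
#include <cstdint>
#include <vector>

// Illustrative mask-and-scale dropout: an element survives when its random
// 16-bit value z is at or below the keep threshold, and survivors are
// rescaled by rp_dropout = 1 / p_dropout to preserve the expectation.
void reference_dropout_sketch(const std::vector<uint16_t>& z,
                              const std::vector<float>& in,
                              std::vector<float>& out,
                              uint16_t p_dropout_in_16bits,
                              float rp_dropout)
{
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i] = (z[i] <= p_dropout_in_16bits) ? in[i] * rp_dropout : 0.0f;
}
```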
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+63, -29)
```diff
@@ -5,12 +5,12 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;
     bool input_permute   = false;
     bool output_permute  = true;
 
-    float p_drop                 = 0.2;
+    float p_drop                 = 0.1;
     float p_dropout              = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout             = 1.0 / p_dropout;
```
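For the new default this works out as follows (a worked sketch of the arithmetic above, not code from the commit):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_drop    = 0.1f;
    float p_dropout = 1 - p_drop; // keep probability = 0.9
    // threshold the 16-bit RNG output is compared against: floor(0.9 * 65535) = 58981
    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    // rescale factor for surviving elements: 1 / 0.9
    float rp_dropout = 1.0 / p_dropout;
    std::printf("%u %.4f\n", p_dropout_in_16bits, rp_dropout); // prints: 58981 1.1111
}
```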
```diff
@@ -45,9 +45,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
 
-    float alpha = 1;    // scaling after 1st gemm
+    float alpha = 0.25; // scaling after 1st gemm
 
-    std::size_t group_count = 7;
+    std::size_t group_count = 8;
 
     // Problem descs
     std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
```
```diff
@@ -79,13 +79,24 @@ int run(int argc, char* argv[])
     std::cout << "group count " << group_count << ". printing first 4 groups\n";
     for(std::size_t i = 0; i < group_count; i++)
     {
-        int M = 128 * (rand() % 8 + 1);
-        int N = 128 * (rand() % 8 + 1);
+        int M = 512;
+        int N = 512;
         int K = 40;
-        int O = 40 * (rand() % 2 + 1);
+        int O = 40;
         int G0 = rand() % 3 + 1;
         int G1 = rand() % 5 + 1;
 
+        // int M = 128 * (rand() % 8 + 1);
+        // int N = 128 * (rand() % 8 + 1);
+        // int K = 40;
+        // int O = 40 * (rand() % 2 + 1);
+        // int G0 = rand() % 3 + 1;
+        // int G1 = rand() % 5 + 1;
+
+        std::cout << "group id" << i << " M, N, K, O, G0, G1 is " << M << "," << N << "," << K
+                  << "," << O << "," << G0 << "," << G1 << std::endl;
+
        g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
 
        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
```
```diff
@@ -229,25 +240,26 @@ int run(int argc, char* argv[])
     auto c_element_op    = CElementOp{};
 
     // do GEMM
     auto gemm     = DeviceGemmInstance{};
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(p_a,
                                       p_b0,
                                       p_b1,
                                       p_c,
                                       p_z,
                                       p_lse,
                                       {}, // p_acc0_biases
                                       {}, // p_acc1_biases
                                       problem_descs,
                                       a_element_op,
                                       b0_element_op,
                                       acc0_element_op,
                                       b1_element_op,
                                       c_element_op,
                                       p_drop,          // dropout ratio
-                                      {0, 448});       // dropout random seed and offset, offset should be
-                                                       // at least the number of elements on a thread
+                                      {seed, offset}); // dropout random seed and offset, offset should be
+                                                       // at least the number of elements on a thread
 
     // specify workspace for problem_desc
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
```
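The literal `{0, 448}` becomes `{seed, offset}`; the pair parameterizes the counter-based RNG that generates the dropout mask, and, per the comment, the offset must be at least the number of random values one thread consumes so that successive launches do not replay the same stream. `seed` and `offset` are presumably defined earlier in the new version of this file (not visible in the captured hunks). A hedged sketch of the constraint, with illustrative names and sizes:

```cpp
// Sketch only: not from the commit. A counter-based RNG derives each
// thread's stream from (seed, offset); advancing offset past the values a
// thread already consumed keeps consecutive launches statistically independent.
unsigned long long seed   = 0;
unsigned long long offset = 0;
const unsigned long long rng_values_per_thread = 448; // assumed per-thread consumption
// ... launch kernel with {seed, offset} ...
offset += rng_values_per_thread; // advance before the next launch
```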
```diff
@@ -291,11 +303,14 @@ int run(int argc, char* argv[])
         const auto& b0_gs_ns_ks        = b0_tensors[i];
         const auto& b1_gs_os_ns        = b1_tensors[i];
         auto& c_gs_ms_os_device_result = c_tensors[i];
+        auto& z_gs_ms_ns_device_result = z_tensors[i];
         auto& lse_gs_ms_device_result  = lse_tensors[i];
         auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];
+        auto& z_gs_ms_ns_device_buf    = *z_tensors_device[i];
         auto& lse_gs_ms_device_buf     = *lse_tensors_device[i];
 
         c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+        z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
         lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
 
         Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
```
```diff
@@ -303,8 +318,11 @@ int run(int argc, char* argv[])
         Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
         Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N});        // scratch object after gemm0
         Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});            // scratch object after softmax
+        Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});       // scratch object after softmax
         Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
         Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
+        // Tensor<CDataType> z_gs_ms_ns_host_result(z_gs_ms_os_lengths, z_gs_ms_os_strides);
         Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M});  // scratch object after gemm1
         Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
```
```diff
@@ -319,6 +337,10 @@ int run(int argc, char* argv[])
             b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
         });
+        z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
+            z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
 
         // gemm 0
         auto ref_gemm0         = ReferenceGemm0Instance{};
         auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
```
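The new `ForEach` mirrors the existing remaps: the device-side Z result is laid out as [G0, G1, M, N], while the host reference works on a flattened batch. As a one-line sketch of the index math used above:

```cpp
// Flatten (g0, g1) into one batch index g for the [G0*G1, M, N] host tensor,
// so that z_g_m_n(g0 * G1 + g1, m, n) = z_gs_ms_ns(g0, g1, m, n).
inline int flatten_batch(int g0, int g1, int G1) { return g0 * G1 + g1; }
```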
```diff
@@ -342,10 +364,20 @@ int run(int argc, char* argv[])
         ref_softmax_invoker.Run(ref_softmax_argument);
 
+        // printf("print z_g_m_n \n");
+        // z_g_m_n.ForEach([&](auto& self, auto idx) {printf("%u ", self(idx));});
+
+        // dropout after softmax
+        auto ref_dropout         = ReferenceDropoutInstance{};
+        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+        auto ref_dropout_argment = ref_dropout.MakeArgument(
+            z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+        ref_dropout_invoker.Run(ref_dropout_argment);
+
         // gemm 1
         auto ref_gemm1          = ReferenceGemm1Instance{};
         auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
-        auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
+        auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n_drop,
                                                          b1_g_n_o,
                                                          c_g_m_o_host_result,
                                                          PassThrough{},
```
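With this change the host reference mirrors the fused kernel stage by stage: gemm0, scaling, softmax, dropout, then gemm1 consuming the dropped tensor. In outline (arguments abbreviated; the real calls go through the MakeArgument/Invoker pattern shown above):

```cpp
// Host verification pipeline in this file, in call order:
// acc0_g_m_n    = alpha * (a_g_m_k x b0)         -- ref_gemm0 + acc0_element_op
// a1_g_m_n      = softmax(acc0_g_m_n)            -- ref_softmax
// a1_g_m_n_drop = dropout(a1_g_m_n, z_g_m_n)     -- ref_dropout (new in this commit)
// c_g_m_o_host_result = a1_g_m_n_drop x b1_g_n_o -- ref_gemm1, now fed the dropped tensor
```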
```diff
@@ -384,9 +416,11 @@ int run(int argc, char* argv[])
                 atol = 1e-2;
             }
 
+            printf("group id is %lu \n", i);
+
             // bool pass_ =
             //     ck::utils::check_err(c_gs_ms_os_device_result.mData,
             //                          c_gs_ms_os_host_result.mData);
             bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                               c_gs_ms_os_host_result.mData,
                                               "Error: Incorrect results c!",
```
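`ck::utils::check_err` compares the device output element-wise against the host reference under the `rtol`/`atol` set just above. A minimal sketch of the usual mixed-tolerance criterion (check_err's exact policy may differ):

```cpp
#include <cmath>

// True when dev is within atol plus an rtol-scaled fraction of the reference.
bool nearly_equal(float dev, float ref, float rtol, float atol)
{
    return std::fabs(dev - ref) <= atol + rtol * std::fabs(ref);
}
```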
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+8, -6)
```diff
@@ -97,17 +97,18 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
 
+    unsigned short* p_z_grid_in = //
+        (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr //
+                                                : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
+
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
         arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        p_z_grid_in,
+        // arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+        //     : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
         arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
         a_element_op,
```
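`__builtin_amdgcn_readfirstlane` keeps the per-batch offsets in scalar registers: every lane in the wavefront computes the same group and batch offset, so broadcasting lane 0's value lets the compiler keep the pointer arithmetic on the scalar unit. Hoisting the Z pointer selection into `p_z_grid_in` does not change behavior; it only lifts the nullptr check out of the `Run` call's argument list.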
```diff
@@ -417,6 +418,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
         ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
     };
```
```diff
@@ -621,7 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
             // typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
             //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
-            auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                     z_grid_desc_m_n);
```
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+6, -30)
```diff
@@ -139,23 +139,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }
 
-    __host__ __device__ static constexpr auto
-    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M,
-                                                      const index_t N) ////=> for z use
-    {
-        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
-        constexpr auto N3   = mfma.num_groups_per_blk;
-        constexpr auto N4   = mfma.num_input_blks;
-        constexpr auto N5   = mfma.group_size;
-
-        return transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(M, N)),
-            make_tuple(make_unmerge_transform(
-                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
-                       make_unmerge_transform(
-                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
-    }
-
     __device__ static auto GetGemm0WaveIdx()
     {
```
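The deleted helper built a 10-dimensional view of the flat (M, N) Z matrix matching the xdlops output layout, with N3/N4/N5 taken from the selected MFMA instruction (num_groups_per_blk, num_input_blks, group_size); the grid layer now receives an equivalent descriptor built by the device layer (see the MakeCGridDescriptor call in the device file above). The unmerge transforms amount to the following index factorization, shown as a hedged host-side sketch with made-up tile sizes:

```cpp
#include <cstdio>

int main()
{
    // Hypothetical tile parameters, for illustration only.
    const int MXdlPerWave = 4, Gemm0MWaves = 2, MPerXdl = 32;

    // The M-side unmerge splits row m into (m0, m1, m2, m3):
    const int m0 = 1, m1 = 2, m2 = 0, m3 = 7;
    const int m  = ((m0 * MXdlPerWave + m1) * Gemm0MWaves + m2) * MPerXdl + m3;

    // The N-side unmerge does the same with six factors
    // (N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5).
    std::printf("m = %d\n", m); // 391: the flat row the 10-d coordinates map back to
}
```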
```diff
@@ -852,7 +835,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         index_t gemm1_k_block_outer_index = 0;
 
         ///////////////////=>z for dropout
         //
         // z vgpr copy to global
         //
```
```diff
@@ -876,11 +858,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             z_tenor_buffer;
         z_tenor_buffer.Clear();
 
         // z matrix global desc
-        /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
-        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
-        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-            MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
-
         auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
```
```diff
@@ -1025,7 +1002,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                     // P_dropped
                     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
                                                             decltype(z_tenor_buffer),
-                                                            true>(acc_thread_buf, ph, z_tenor_buffer);
+                                                            false>(acc_thread_buf, ph, z_tenor_buffer);
 
                     z_thread_copy_vgpr_to_global.Run(
```
```diff
@@ -1034,20 +1011,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                         z_tenor_buffer,
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         z_grid_buf);
+                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                 }
                 else
                 {
                     // P_dropped
                     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
-                                                            true>(acc_thread_buf, ph);
+                                                            false>(acc_thread_buf, ph);
                 }
             }
-            // if constexpr(IsDropout) // dropout
-            //{
-            //    blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
-            //}
 
             // TODO: may convert to log domain
             running_max_new = mathext::max(max, running_max);
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
```
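The added `MoveDstSliceWindow` advances the Z write window one step along the N0 dimension of the 10-d descriptor (the second index of `make_multi_index`) after each dump, so each iteration of the outer gemm1 loop writes the next tile of the dropout mask instead of overwriting the first one; this is what makes the per-element Z values visible to the new host-side dropout verification.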