Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
77df3ccb
"...composable_kernel_rocm.git" did not exist on "85fc91c3218c1d85169ed1fe95eef7b07942e648"
Commit
77df3ccb
authored
Aug 17, 2023
by
letaoqin
Browse files
format
parent
48f98948
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
40 deletions
+42
-40
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+5
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+37
-36
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
77df3ccb
...
@@ -119,12 +119,13 @@ __global__ void
...
@@ -119,12 +119,13 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
){
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
}
}
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
...
@@ -186,7 +187,7 @@ __global__ void
...
@@ -186,7 +187,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
77df3ccb
...
@@ -1191,43 +1191,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1191,43 +1191,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_M0_O_M1
>
typename
YGradGridDesc_M0_O_M1
>
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
__device__
static
void
const
InputDataType
*
__restrict__
p_k_grid
,
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
const
D0DataType
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_k_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
D0DataType
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
OutputDataType
*
__restrict__
p_qgrad_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
OutputDataType
*
__restrict__
p_kgrad_grid
,
OutputDataType
*
__restrict__
p_qgrad_grid
,
OutputDataType
*
__restrict__
p_vgrad_grid
,
OutputDataType
*
__restrict__
p_kgrad_grid
,
void
*
__restrict__
p_shared
,
OutputDataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
&
a_element_op
,
void
*
__restrict__
p_shared
,
const
BElementwiseOperation
&
b_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
SElementwiseOperation
&
s_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
SElementwiseOperation
&
s_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
CElementwiseOperation
&
c_element_op
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
const
LSEGridDesc_M
&
lse_grid_desc_m
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
float
p_drop
,
const
C0MatrixMask
&
c0_matrix_mask
,
ck
::
philox
&
ph
,
const
float
p_drop
,
const
index_t
z_random_matrix_offset
,
ck
::
philox
&
ph
,
const
index_t
raw_n_padded
,
const
index_t
z_random_matrix_offset
,
const
index_t
block_idx_n
)
const
index_t
raw_n_padded
,
const
index_t
block_idx_n
)
{
{
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
p_d_grid
;
ignore
=
p_d_grid
;
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
const
ushort
p_dropout_in_16bits
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment