gaoqiong / composable_kernel / Commits / a5bad9f2

Commit a5bad9f2, authored Feb 17, 2023 by danyao12

    aligned with prototype2 dropout

parent 36dc18e8
Showing 2 changed files with 48 additions and 41 deletions (+48, -41)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp (+11, -11)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp (+37, -30)
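Taken together, the two files move the dropout bookkeeping from the host-side Argument into the gridwise kernel: the device layer now forwards the raw drop probability p_drop (plus the philox seed and offset), and the kernel derives the keep probability, its reciprocal, and the 16-bit comparison threshold itself. A minimal standalone sketch of that derivation, using only the quantities visible in the diff; the struct and function names below are illustrative, not CK API:

    #include <cmath>
    #include <cstdint>

    // Illustrative only: mirrors the constants the gridwise kernel now computes
    // from the raw drop probability forwarded by the device layer.
    struct DropoutConstants
    {
        float p_dropout;              // keep probability, 1 - p_drop
        float rp_dropout;             // 1 / (1 - p_drop), the inverted-dropout rescale
        uint16_t p_dropout_in_16bits; // keep threshold for 16-bit philox samples
        bool is_dropout;              // dropout work can be skipped when p_drop == 0
    };

    inline DropoutConstants make_dropout_constants(float p_drop)
    {
        DropoutConstants c{};
        c.p_dropout           = 1.0f - p_drop;
        c.rp_dropout          = 1.0f / c.p_dropout;
        c.p_dropout_in_16bits = static_cast<uint16_t>(std::floor(c.p_dropout * 65535.0));
        c.is_dropout          = p_drop > 0.0f;
        return c;
    }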
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp (view file @ a5bad9f2)

@@ -53,7 +53,7 @@ __global__ void
         // __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
 #endif
-        kernel_batched_multiheadattention_backward_xdl_cshuffle_pt1(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1(
             const DataType* __restrict__ p_a_grid,
             const DataType* __restrict__ p_b_grid,
             ZDataType* __restrict__ p_z_grid,
@@ -83,7 +83,7 @@ __global__ void
             const index_t batch_count,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
-            const float p_dropout,
+            const float p_drop,
             const unsigned long long seed,
             const unsigned long long offset)
 {
@@ -138,7 +138,7 @@ __global__ void
                     ygrad_grid_desc_o0_m_o1,
                     block_2_ctile_map,
                     c0_matrix_mask,
-                    p_dropout,
+                    p_drop,
                     ph);
 #else
     ignore = p_a_grid;
@@ -158,6 +158,9 @@ __global__ void
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
+    ignore = p_drop;
+    ignore = seed;
+    ignore = offset;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
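The three new ignore statements keep the #else branch in step with the widened parameter list: on targets other than gfx908/gfx90a the kernel body is compiled out, so every argument has to be consumed to avoid unused-parameter warnings. A hedged sketch of the same pattern; kernel_stub is a made-up name and std::ignore simply plays the role of the ignore helper used in the diff:

    #include <tuple> // std::ignore

    // Illustrative stub: when the real body is compiled out for an unsupported
    // target, assigning each argument to std::ignore silences unused-parameter
    // warnings without changing behaviour.
    void kernel_stub(const float* p_a_grid,
                     float p_drop,
                     unsigned long long seed,
                     unsigned long long offset)
    {
        std::ignore = p_a_grid;
        std::ignore = p_drop;
        std::ignore = seed;
        std::ignore = offset;
    }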
@@ -758,7 +761,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
               z_grid_desc_g_m_n_,
               b1_grid_desc_g_n_k_,
               c_grid_desc_g_m_n_,
-              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
+              p_drop_{p_drop}
         {
             // TODO: implement bias addition
             ignore = p_acc0_biases;
@@ -779,10 +783,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                     y_grid_desc_m_o_);
             }
-            p_dropout_        = 1.f - p_drop;
-            float rp_dropout_ = 1.f / p_dropout_;
-            acc_element_op_.Append(rp_dropout_);
-
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
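With these lines removed, the host no longer precomputes 1 - p_drop or appends the 1/(1 - p_drop) rescale to acc_element_op_; the Argument just keeps the raw probability in p_drop_ alongside the seed and offset unpacked from the seeds tuple, and the equivalent scaling reappears inside the gridwise kernel as scale_rp_dropout (second file below).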
@@ -873,7 +873,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
-        float p_dropout_;
+        float p_drop_;
         unsigned long long seed_;
         unsigned long long offset_;
     };
@@ -896,7 +896,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multiheadattention_backward_xdl_cshuffle_pt1<
+            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1<
                 GridwiseGemm,
                 DataType,
                 ZDataType,
@@ -951,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                 arg.batch_count_,
                 arg.compute_base_ptr_of_batch_,
                 arg.c0_matrix_mask_,
-                arg.p_dropout_,
+                arg.p_drop_,
                 arg.seed_,
                 arg.offset_);
         };
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp (view file @ a5bad9f2)

@@ -1259,11 +1259,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
-        FloatGemmAcc p_dropout,
+        const float p_drop,
         ck::philox& ph)
     {
+        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
+        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
         const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-        const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
+        const bool is_dropout = p_drop > 0.0f;
+        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
+                                                                     rp_dropout);

         const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
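The notable line here is scale_rp_dropout: rather than applying the attention scale and the inverted-dropout rescale as two separate multiplies, the kernel folds them into one Scale element-wise op built from s_element_op.Value() * rp_dropout. A minimal sketch of that composition; the Scale functor below is a stand-in for the one named in the diff, not CK's actual tensor_operation::element_wise::Scale:

    // Stand-in for a multiply-by-constant element-wise op with a Value() accessor.
    struct Scale
    {
        explicit Scale(float scale) : scale_(scale) {}
        float Value() const { return scale_; }
        void operator()(float& y, float x) const { y = scale_ * x; }
        float scale_;
    };

    // Compose the existing S element-wise scale with the dropout rescale:
    // one epilogue multiply instead of two.
    inline Scale make_scale_rp_dropout(const Scale& s_element_op, float rp_dropout)
    {
        return Scale(s_element_op.Value() * rp_dropout);
    }

An output written through the composed op ends up as y = (scale / (1 - p_drop)) * x, which is what the deleted host-side acc_element_op_.Append(rp_dropout_) call used to arrange.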
@@ -1670,9 +1674,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
             decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-            decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+            decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                     kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-                                    s_element_op);
+                                    scale_rp_dropout);

         //
         // set up Y dot dY
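Switching the dK vgpr-to-global copy's element-wise op from s_element_op to scale_rp_dropout applies the dropout rescale in the same epilogue multiply that already carried the attention scale, instead of as a separate pass over the gradient tile.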
@@ -1871,9 +1875,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

-        const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-        const float scalar = 1.0f / std::sqrt(K);
-
         // Initialize dQ
         qgrad_thread_buf.Clear();
@@ -1966,14 +1967,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                     }
                     else
                     {
-                        s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
+                        s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
             else
             {
-                static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                    [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
+                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                    s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+                });
             }

             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
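Previously the S tile was scaled by a locally computed scalar = 1.0f / std::sqrt(K); now each element is handed to s_element_op, so whatever scale the caller configured is applied through the op's (destination, source) call interface rather than a hard-coded factor. A rough sketch of the pattern with a stand-in functor; ScaleOp and apply_s_element_op are illustrative names:

    #include <cstddef>

    // Stand-in for an element-wise op with CK-style (destination, source) call syntax.
    struct ScaleOp
    {
        float scale;
        void operator()(float& y, float x) const { y = scale * x; }
    };

    // Replaces "buf[i] = scalar * buf[i]" with a configurable element-wise op.
    inline void apply_s_element_op(float* buf, std::size_t n, const ScaleOp& s_element_op)
    {
        for(std::size_t i = 0; i < n; ++i)
        {
            s_element_op(buf[i], buf[i]);
        }
    }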
@@ -1983,6 +1986,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

             // save z to global
+            if(is_dropout)
+            {
             if(p_z_grid)
             {
                 // P_dropped
@@ -1991,7 +1996,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                     true>(s_slash_p_thread_buf, ph, z_tenor_buffer);

                 z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                                  z_tenor_buffer,
                                                  z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -2003,6 +2009,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                 blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
                     s_slash_p_thread_buf, ph);
             }
+            }

             block_sync_lds(); // wait for gemm1 LDS read
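Wrapping the Z save and the ApplyDropout call in if(is_dropout) makes the p_drop == 0 case cheap: no philox numbers are generated and no dropout mask is written when dropout is disabled. Roughly the control flow below, where the two helper functions are placeholders for the calls shown in the diff and the if/else pairing is inferred from the surrounding braces:

    // Placeholders for the Z-save branch (ApplyDropout into z_tenor_buffer plus
    // z_thread_copy_vgpr_to_global.Run) and the in-place ApplyDropout call.
    inline void generate_and_store_z_mask() {}
    inline void apply_dropout_in_place() {}

    inline void dropout_step(bool is_dropout, bool has_z_buffer)
    {
        if(is_dropout)
        {
            if(has_z_buffer)
            {
                generate_and_store_z_mask(); // store the mask while applying it
            }
            else
            {
                apply_dropout_in_place();
            }
        }
        // when p_drop == 0 the whole block is skipped
    }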
@@ -2306,7 +2313,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                         n_thread_data_on_block_idx[I2],
                         n_thread_data_on_block_idx[I3],
                         n_thread_data_on_block_idx[I4]),
-                    s_element_op};
+                    scale_rp_dropout};

             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
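The same substitution appears at the start of the C-shuffle output stage: the copy object constructed just before the LDS-to-global blockwise copy now takes scale_rp_dropout as its element-wise op, so this output path also picks up the 1/(1 - p_drop) rescale on top of the original scale.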