Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7d6a8ec7
Commit
7d6a8ec7
authored
Jun 19, 2023
by
guangzlu
Browse files
added dropout to fwd_v2 and bwd_qoop
parent
6d63c311
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
9 deletions
+77
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+27
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+27
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+23
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
7d6a8ec7
...
@@ -132,6 +132,9 @@ __global__ void
...
@@ -132,6 +132,9 @@ __global__ void
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout
,
p_dropout
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
i
);
i
);
}
}
}
}
...
@@ -165,6 +168,9 @@ __global__ void
...
@@ -165,6 +168,9 @@ __global__ void
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout
,
p_dropout
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
0
);
0
);
}
}
#else
#else
...
@@ -654,7 +660,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -654,7 +660,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
LSEGridDesc_M
lse_grid_desc_m_
;
LSEGridDesc_M
lse_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
...
@@ -667,6 +673,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -667,6 +673,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// check C0 masking and padding
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
C0MatrixMask
c0_matrix_mask_
;
index_t
block_start_
,
block_end_
;
index_t
block_start_
,
block_end_
;
index_t
z_random_matrix_offset_
;
index_t
raw_m_padded_
,
raw_n_padded_
;
};
};
struct
GroupDeviceArg
struct
GroupDeviceArg
...
@@ -740,6 +749,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -740,6 +749,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
grid_size_
=
0
;
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -787,7 +799,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -787,7 +799,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
...
@@ -802,7 +814,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -802,7 +814,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n
);
z_grid_desc_m_n
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
...
@@ -836,6 +848,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -836,6 +848,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
"match that in template argument"
);
"match that in template argument"
);
}
}
const
auto
raw_m_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
]);
const
auto
raw_n_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
]);
group_kernel_args_
.
push_back
({
p_a_grid
,
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_z_grid
,
p_z_grid
,
...
@@ -861,7 +878,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -861,7 +878,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
compute_base_ptr_of_batch
,
compute_base_ptr_of_batch
,
c0_matrix_mask
,
c0_matrix_mask
,
BlockStart
,
BlockStart
,
BlockEnd
});
BlockEnd
,
z_random_matrix_offset
,
raw_m_padded
,
raw_n_padded
});
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
7d6a8ec7
...
@@ -132,6 +132,9 @@ __global__ void
...
@@ -132,6 +132,9 @@ __global__ void
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout
,
p_dropout
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
i
);
i
);
}
}
}
}
...
@@ -165,6 +168,9 @@ __global__ void
...
@@ -165,6 +168,9 @@ __global__ void
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout
,
p_dropout
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
0
);
0
);
}
}
#else
#else
...
@@ -662,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -662,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
LSEGridDesc_M
lse_grid_desc_m_
;
LSEGridDesc_M
lse_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
...
@@ -675,6 +681,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -675,6 +681,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// check C0 masking and padding
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
C0MatrixMask
c0_matrix_mask_
;
index_t
block_start_
,
block_end_
;
index_t
block_start_
,
block_end_
;
index_t
z_random_matrix_offset_
;
index_t
raw_m_padded_
,
raw_n_padded_
;
};
};
struct
GroupDeviceArg
struct
GroupDeviceArg
...
@@ -748,6 +757,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -748,6 +757,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
}
grid_size_
=
0
;
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -795,7 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -795,7 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
...
@@ -810,7 +822,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -810,7 +822,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n
);
z_grid_desc_m_n
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
...
@@ -844,6 +856,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -844,6 +856,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
"match that in template argument"
);
"match that in template argument"
);
}
}
const
auto
raw_m_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
]);
const
auto
raw_n_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
]);
group_kernel_args_
.
push_back
({
p_a_grid
,
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_z_grid
,
p_z_grid
,
...
@@ -869,7 +886,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -869,7 +886,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
compute_base_ptr_of_batch
,
compute_base_ptr_of_batch
,
c0_matrix_mask
,
c0_matrix_mask
,
BlockStart
,
BlockStart
,
BlockEnd
});
BlockEnd
,
z_random_matrix_offset
,
raw_m_padded
,
raw_n_padded
});
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7d6a8ec7
...
@@ -132,6 +132,9 @@ __global__ void
...
@@ -132,6 +132,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
i
);
i
);
}
}
}
}
...
@@ -165,6 +168,9 @@ __global__ void
...
@@ -165,6 +168,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
0
);
0
);
}
}
#else
#else
...
@@ -567,6 +573,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -567,6 +573,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
block_start_
,
block_end_
;
index_t
block_start_
,
block_end_
;
index_t
z_random_matrix_offset_
;
index_t
raw_m_padded_
,
raw_n_padded_
;
};
};
struct
GroupDeviceArg
struct
GroupDeviceArg
...
@@ -626,6 +635,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -626,6 +635,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
grid_size_
=
0
;
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
...
@@ -712,6 +723,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -712,6 +723,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
"match that in template argument"
);
"match that in template argument"
);
}
}
const
auto
raw_m_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
]);
const
auto
raw_n_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
]);
group_kernel_args_
.
push_back
({
p_a_grid
,
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_b1_grid
,
p_b1_grid
,
...
@@ -730,7 +746,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -730,7 +746,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
c0_matrix_mask
,
c0_matrix_mask
,
block_2_ctile_map
,
block_2_ctile_map
,
BlockStart
,
BlockStart
,
BlockEnd
});
BlockEnd
,
z_random_matrix_offset
,
raw_m_padded
,
raw_n_padded
});
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment