Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c0c52268
Unverified
Commit
c0c52268
authored
Sep 22, 2023
by
Dan Yao
Committed by
GitHub
Sep 22, 2023
Browse files
Merge pull request #905 from ROCmSoftwarePlatform/mha-train-develop-grad-bias
flash attention output bias grad
parents
f04ec574
c88d1173
Changes
22
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
201 additions
and
69 deletions
+201
-69
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+178
-69
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+23
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
c0c52268
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
c0c52268
...
...
@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
return
matrix_padder
.
PadCDescriptor_M_N
(
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
}
//
// C0
//
static
auto
MakeC0GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimN
,
CSpec
>
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
);
}
// TODO: rename to G_MRaw_NRaw
static
auto
MakeC0GridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
MakeC0GridDescriptorPair
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
).
first
;
}
static
auto
MakeC0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides_vec
)
{
return
matrix_padder
.
PadC0Descriptor_M_N
(
MakeC0GridDescriptorPair
(
c_gs_ms_ns_lengths_vec
,
c_gs_ms_ns_strides_vec
).
second
);
}
};
}
// namespace tensor_operation
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment