Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b03f56db
Commit
b03f56db
authored
Feb 03, 2023
by
danyao12
Browse files
add SetA/BBlockStartWindow in BlockwiseGemmXdlops_v2
parent
fced127d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
17 deletions
+12
-17
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+8
-14
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+4
-3
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
100644 → 100755
View file @
b03f56db
...
@@ -859,24 +859,19 @@ struct BlockwiseGemmXdlops_v2
...
@@ -859,24 +859,19 @@ struct BlockwiseGemmXdlops_v2
"wrong!"
);
"wrong!"
);
}
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
index_t
switch_flag
,
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
(),
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
())
:
switch_flag_
(
switch_flag
),
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
}
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
__device__
void
SetABlockStartWindow
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
())
"wrong!"
);
{
a_thread_copy_
.
SetSrcCoord
(
a_origin
);
}
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
__device__
void
SetBBlockStartWindow
(
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
{
{
b_thread_copy_
.
SetSrcCoord
(
b_origin
);
}
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
...
@@ -1141,7 +1136,6 @@ struct BlockwiseGemmXdlops_v2
...
@@ -1141,7 +1136,6 @@ struct BlockwiseGemmXdlops_v2
B_K1
,
B_K1
,
B_K1
>
;
B_K1
>
;
index_t
switch_flag_
;
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
b03f56db
...
@@ -1333,8 +1333,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1333,8 +1333,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename
Gemm1
::
ABlockwiseCopy
{
tensor_operation
::
element_wise
::
PassThrough
{}};
typename
Gemm1
::
ABlockwiseCopy
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// dQ: blockwise gemm
// dQ: blockwise gemm
auto
qgrad_blockwise_gemm
=
auto
qgrad_blockwise_gemm
=
typename
Gemm1
::
BlockwiseGemm
{};
typename
Gemm1
::
B
lockwise
G
emm
{
make_tuple
(
0
,
0
,
0
,
0
),
make_tuple
(
0
,
0
,
0
,
0
)
}
;
qgrad_b
lockwise
_g
emm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
)
)
;
// dQ: B matrix blockwise copy
// dQ: B matrix blockwise copy
auto
k_thread_origin
=
qgrad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
auto
k_thread_origin
=
qgrad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
...
@@ -1458,7 +1458,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1458,7 +1458,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
// dV: blockwise gemm
// dV: blockwise gemm
auto
v_slash_k_grad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{
1
,
make_tuple
(
0
,
0
,
0
,
0
)};
auto
v_slash_k_grad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{};
v_slash_k_grad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
auto
q_slash_ygrad_thread_origin
=
v_slash_k_grad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
auto
q_slash_ygrad_thread_origin
=
v_slash_k_grad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment