Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fd37b2cf
"scripts/git@developer.sourcefind.cn:change/sglang.git" did not exist on "88bb627d0d224ad4195cc068cdca30f0b3634b48"
Commit
fd37b2cf
authored
May 30, 2023
by
danyao12
Browse files
fix decoder related Q&dO reading bugs
parent
758d0281
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
18 deletions
+11
-18
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
+10
-17
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
View file @
fd37b2cf
...
@@ -695,7 +695,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -695,7 +695,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
B1BlockTransferDstScalarPerVector_BK1
,
fals
e
,
tru
e
,
B1BlockLdsExtraN
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
View file @
fd37b2cf
...
@@ -1090,23 +1090,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1090,23 +1090,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
return
q_grid_desc_m0_k_m1
;
return
q_grid_desc_m0_k_m1
;
}
}
// // C position
// template <typename KGridDesc_K0_N_K1_>
// __device__ static auto MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_&
// k_grid_desc_k0_n_k1)
// {
// const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
// const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
// const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
// return transform_tensor_descriptor(
// k_grid_desc_k0_n_k1,
// make_tuple(make_pass_through_transform(N),
// make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
};
};
struct
SharedMemTrait
struct
SharedMemTrait
...
@@ -1380,6 +1363,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1380,6 +1363,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
const
auto
vgrad_gemm_tile_ygrad_block_next_copy_step
=
make_multi_index
(
MPerBlock
/
B1K1
,
0
,
0
);
// dV: blockwise gemm
// dV: blockwise gemm
auto
vgrad_blockwise_gemm
=
auto
vgrad_blockwise_gemm
=
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
...
@@ -1410,6 +1396,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1410,6 +1396,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
const
auto
kgrad_gemm_tile_q_block_next_copy_step
=
make_multi_index
(
MPerBlock
/
B1K1
,
0
,
0
);
// dK: blockwise gemm
// dK: blockwise gemm
auto
kgrad_blockwise_gemm
=
auto
kgrad_blockwise_gemm
=
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
...
@@ -1744,6 +1733,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1744,6 +1733,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
Gemm2
::
c_block_slice_copy_step
);
// step M
kgrad_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
q_grid_desc_m0_k_m1
,
kgrad_gemm_tile_q_block_next_copy_step
);
// step M
vgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
ygrad_grid_desc_m0_o_m1
,
vgrad_gemm_tile_ygrad_block_next_copy_step
);
// step M
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
));
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
));
yygrad_threadwise_copy
.
MoveSrcSliceWindow
(
yygrad_threadwise_copy
.
MoveSrcSliceWindow
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment