Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a614e299
Commit
a614e299
authored
Sep 14, 2022
by
wangshaojie6
Browse files
Merge branch 'att_diagnal' into att_lower_triangle
parents
1ebc21d4
e392ce24
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
110 additions
and
1 deletion
+110
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+109
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
a614e299
...
...
@@ -97,6 +97,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
// Gemm1
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
...
...
@@ -262,6 +266,39 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
warpSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
warpSize
/
MPerXdl
,
MPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0MNIdx
(
const
index_t
n4
)
{
auto
waveIdx
=
GetGemm0WaveIdx
();
auto
waveMNIdx
=
GetGemm0WaveMNIdx
(
waveIdx
[
I2
]);
auto
MIdx
=
waveIdx
[
I0
]
*
MPerXdl
+
waveMNIdx
[
I1
];
auto
NIdx
=
waveIdx
[
I1
]
*
NPerXdl
+
waveMNIdx
[
I0
]
*
n4
;
return
make_tuple
(
NIdx
,
MIdx
);
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
...
...
@@ -563,6 +600,37 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr
auto
b1_block_slice_copy_step
=
make_multi_index
(
Gemm1KPerBlock
/
B1K1
,
0
,
0
);
// get m/n id
//const auto wave_id = GetGemm0WaveIdx();
//const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
const
auto
m_n_id
=
GetGemm0MNIdx
(
n4
);
#if 0
if(blockIdx.x == 0)
{
printf("tid=%d, wave mn id=[%d, %d], mn id=[%d, %d]\n",
static_cast<int>(threadIdx.x),
wave_m_n_id[I0],
wave_m_n_id[I1],
m_n_id[I0],
m_n_id[I1]);
}
if(blockIdx.x == 0 && threadIdx.x == 0)
{
printf("%d, %d, %d, %d, %d, %d, %d, %d\n",
static_cast<int>(m0),
static_cast<int>(n0),
static_cast<int>(m1),
static_cast<int>(n1),
static_cast<int>(m2),
static_cast<int>(n2),
static_cast<int>(n3),
static_cast<int>(n4));
}
#endif
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
...
...
@@ -591,6 +659,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
A1ThreadSlice_K0_M_K1
,
make_tuple
(
A1ThreadSliceM
*
A1ThreadSliceK1
,
A1ThreadSliceK1
,
I1
));
#if 0
if(threadIdx.x == 0)
{
printf("bid=%d, A1ThreadSliceK0=%d, A1ThreadSliceM=%d, A1ThreadSliceK1=%d\n",
static_cast<int>(blockIdx.x),
static_cast<int>(A1ThreadSliceK0),
static_cast<int>(A1ThreadSliceM),
static_cast<int>(A1ThreadSliceK1));
}
#endif
// B1 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
...
...
@@ -762,6 +841,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t
gemm1_k_block_outer_index
=
0
;
do
{
if
((
m_block_data_idx_on_grid
<
gemm1_k_block_outer_index
*
NPerBlock
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm1_k_block_outer_index
*
NPerBlock
+
NPerBlock
-
1
)))
{
continue
;
}
// gemm0
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
...
...
@@ -814,6 +897,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
});
static_for
<
0
,
MXdlPerWave
,
1
>
{}([
&
](
auto
i_m0
){
static_for
<
0
,
NXdlPerWave
,
1
>
{}([
&
](
auto
i_n0
){
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
i_n2
){
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
i_n4
){
auto
global_m_idx
=
m_block_data_idx_on_grid
+
m_n_id
[
I1
]
+
i_m0
*
Gemm0MWaves
*
MPerXdl
;
auto
global_n_idx
=
gemm1_k_block_outer_index
*
NPerBlock
+
m_n_id
[
I0
]
+
i_n0
*
Gemm0NWaves
*
NPerXdl
+
i_n4
+
i_n2
*
n4
*
(
warpSize
/
MPerXdl
);
#if 0
if(blockIdx.x == 0 && i_m0 == 0 && i_n0 == 0)
{
printf("tid=%d, global_mn_idx=[%d, %d]\n",
static_cast<int>(threadIdx.x),
global_m_idx,
global_n_idx);
}
#endif
if
(
global_n_idx
>
global_m_idx
)
{
acc_thread_buf
(
i_m0
*
n0
*
n2
*
n4
+
i_n0
*
n2
*
n4
+
i_n2
*
n4
+
i_n4
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
});
});
});
});
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// softmax
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
a614e299
...
...
@@ -195,7 +195,7 @@ struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator
AccDataType
v_c
;
if
(
n
<=
m
)
if
(
((
n
>>
0
)
<<
0
)
<=
((
m
>>
0
)
<<
0
)
)
{
arg
.
c_element_op_
(
v_c
,
v_acc
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment