gaoqiong / composable_kernel

Commit 7b18e6fd
authored Sep 14, 2022 by wangshaojie6
parent a614e299

    attention with lower triangle mask with tile skipping
Showing 2 changed files, with 3 additions and 104 deletions:

  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp  (+2, -2)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+1, -102)
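The commit title describes causal (lower-triangle) attention masking combined with tile skipping. As a rough illustration of the skipping idea — a standalone sketch with hypothetical names, not code from this repository — a score tile that lies entirely above the diagonal is fully masked, so its Gemm0/softmax work can be bypassed outright:

#include <cstdint>

// Hypothetical helper: with a lower-triangle (causal) mask, element (m, n)
// is kept only when n <= m. If even the last row of a tile (its largest m)
// is smaller than the tile's first column (its smallest n), every element
// of the tile is masked, and the whole tile can be skipped.
bool TileIsFullyMasked(int64_t m_tile_start, int64_t m_tile_size,
                       int64_t n_tile_start)
{
    const int64_t m_last = m_tile_start + m_tile_size - 1;
    return n_tile_start > m_last;
}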
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp

@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 256;
+    ck::index_t M = 512;
+    ck::index_t N = 512;
     ck::index_t K = 64;
     ck::index_t O = 128;
     ck::index_t StrideA = -1;
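For context on the shape change above: the example chains two GEMMs, so A (M x K) times B0 (K x N) produces the M x N score matrix that the lower-triangle mask acts on, and that result times B1 (N x O) produces C (M x O). Bumping both M and N to 512 keeps the score matrix square, which the raw index comparison used by the mask (see the last hunk below) implicitly relies on. A minimal sanity-check sketch, assuming those roles for A/B0/B1:

#include <cassert>

int main()
{
    const int M = 512, N = 512, K = 64, O = 128;
    // Gemm0: A (M x K) * B0 (K x N) -> scores (M x N); the lower-triangle
    // mask is applied to this M x N matrix, so M == N keeps it square.
    // Gemm1: softmax(scores) (M x N) * B1 (N x O) -> C (M x O).
    assert(M == N); // the masking in this example compares raw m/n indices
    (void)K; (void)O;
    return 0;
}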
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -266,39 +266,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

-    __device__ static auto GetGemm0WaveIdx()
-    {
-        const index_t thread_id = get_thread_local_1d_id();
-
-        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, warpSize))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
-
-        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
-    }
-
-    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
-    {
-        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(warpSize / MPerXdl, MPerXdl))),
-            make_tuple(Sequence<0, 1>{}),
-            make_tuple(Sequence<0>{}));
-
-        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
-    }
-
-    __device__ static auto GetGemm0MNIdx(const index_t n4)
-    {
-        auto waveIdx   = GetGemm0WaveIdx();
-        auto waveMNIdx = GetGemm0WaveMNIdx(waveIdx[I2]);
-
-        auto MIdx = waveIdx[I0] * MPerXdl + waveMNIdx[I1];
-        auto NIdx = waveIdx[I1] * NPerXdl + waveMNIdx[I0] * n4;
-
-        return make_tuple(NIdx, MIdx);
-    }
-
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
     {
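The three helpers removed above decomposed a flat thread id into wave and lane coordinates via merge-transform adaptors. Restated as plain index arithmetic — a sketch with assumed example values for the compile-time constants, not the library's adaptor machinery:

#include <cstdio>
#include <tuple>

constexpr int warpSize    = 64; // wavefront size assumed for AMD GPUs
constexpr int Gemm0MWaves = 2;  // example values; the real ones are template
constexpr int Gemm0NWaves = 2;  // parameters of the gridwise kernel
constexpr int MPerXdl     = 32;
constexpr int NPerXdl     = 32;

// Inverse of the (Gemm0MWaves, Gemm0NWaves, warpSize) merge transform:
// flat thread id -> (m_wave, n_wave, lane), lane fastest-varying.
std::tuple<int, int, int> Gemm0WaveIdx(int thread_id)
{
    const int lane   = thread_id % warpSize;
    const int n_wave = (thread_id / warpSize) % Gemm0NWaves;
    const int m_wave = thread_id / (warpSize * Gemm0NWaves);
    return {m_wave, n_wave, lane};
}

// Inverse of the (warpSize / MPerXdl, MPerXdl) merge transform:
// lane -> (n group, m within xdlops tile), mirroring GetGemm0WaveMNIdx.
std::tuple<int, int> Gemm0WaveMNIdx(int lane)
{
    const int m_in_tile = lane % MPerXdl;
    const int n_group   = lane / MPerXdl; // in [0, warpSize / MPerXdl)
    return {n_group, m_in_tile};
}

int main()
{
    auto [mw, nw, lane] = Gemm0WaveIdx(150);
    auto [ngrp, m_in]   = Gemm0WaveMNIdx(lane);
    // GetGemm0MNIdx then combined these as:
    //   MIdx = mw * MPerXdl + m_in;  NIdx = nw * NPerXdl + ngrp * n4;
    std::printf("wave=(%d,%d) lane=%d group=%d m=%d\n", mw, nw, lane, ngrp, m_in);
}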
@@ -600,37 +567,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

-        // get m/n id
-        // const auto wave_id     = GetGemm0WaveIdx();
-        // const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
-        const auto m_n_id = GetGemm0MNIdx(n4);
-
-#if 0
-        if(blockIdx.x == 0)
-        {
-            printf("tid=%d, wave mn id=[%d, %d], mn id=[%d, %d]\n",
-                   static_cast<int>(threadIdx.x),
-                   wave_m_n_id[I0],
-                   wave_m_n_id[I1],
-                   m_n_id[I0],
-                   m_n_id[I1]);
-        }
-        if(blockIdx.x == 0 && threadIdx.x == 0)
-        {
-            printf("%d, %d, %d, %d, %d, %d, %d, %d\n",
-                   static_cast<int>(m0),
-                   static_cast<int>(n0),
-                   static_cast<int>(m1),
-                   static_cast<int>(n1),
-                   static_cast<int>(m2),
-                   static_cast<int>(n2),
-                   static_cast<int>(n3),
-                   static_cast<int>(n4));
-        }
-#endif
-
         // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
         // n0_n1_n2_n3 -> k0
         // m0_m1_m2 -> m
@@ -659,17 +595,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             A1ThreadSlice_K0_M_K1,
             make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));

-#if 0
-        if(threadIdx.x == 0)
-        {
-            printf("bid=%d, A1ThreadSliceK0=%d, A1ThreadSliceM=%d, A1ThreadSliceK1=%d\n",
-                   static_cast<int>(blockIdx.x),
-                   static_cast<int>(A1ThreadSliceK0),
-                   static_cast<int>(A1ThreadSliceM),
-                   static_cast<int>(A1ThreadSliceK1));
-        }
-#endif
-
         // B1 matrix in LDS memory, dst of blockwise copy
         constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
@@ -873,7 +798,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
                 const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
                 static_for<0, n2, 1>{}([&](auto n2_i) {
-                    const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * (warpSize / MPerXdl) * n4;
+                    const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
                     const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
                     static_for<0, n4, 1>{}([&](auto n4_i) {
                         const index_t n_global = nstartgroup + n4_i;
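The only change in this hunk swaps the inline expression warpSize / MPerXdl for AccN3. The diff does not show AccN3's definition, so this assumes it is a compile-time constant naming the same lane-grouping ratio, which would make the edit purely a readability change. Under that assumption, a self-contained check (with illustrative values, not the library's) could look like:

// Assumed values for illustration only.
constexpr int warpSize = 64;
constexpr int MPerXdl  = 32;
constexpr int AccN3    = warpSize / MPerXdl; // accumulator N3 extent

static_assert(AccN3 == warpSize / MPerXdl,
              "the swap in this hunk assumes these are the same quantity");

int main() { return 0; }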
@@ -897,32 +822,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 });
             });

-            static_for<0, MXdlPerWave, 1>{}([&](auto i_m0) {
-                static_for<0, NXdlPerWave, 1>{}([&](auto i_n0) {
-                    static_for<0, n2, 1>{}([&](auto i_n2) {
-                        static_for<0, n4, 1>{}([&](auto i_n4) {
-                            auto global_m_idx = m_block_data_idx_on_grid + m_n_id[I1] +
-                                                i_m0 * Gemm0MWaves * MPerXdl;
-                            auto global_n_idx = gemm1_k_block_outer_index * NPerBlock + m_n_id[I0] +
-                                                i_n0 * Gemm0NWaves * NPerXdl + i_n4 +
-                                                i_n2 * n4 * (warpSize / MPerXdl);
-#if 0
-                            if(blockIdx.x == 0 && i_m0 == 0 && i_n0 == 0)
-                            {
-                                printf("tid=%d, global_mn_idx=[%d, %d]\n",
-                                       static_cast<int>(threadIdx.x),
-                                       global_m_idx,
-                                       global_n_idx);
-                            }
-#endif
-                            if(global_n_idx > global_m_idx)
-                            {
-                                acc_thread_buf(i_m0 * n0 * n2 * n4 + i_n0 * n2 * n4 + i_n2 * n4 +
-                                               i_n4) = -ck::NumericLimits<float>::Infinity();
-                            }
-                        });
-                    });
-                });
-            });
-
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm

             // softmax
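The block removed above applied the lower-triangle mask element by element, writing negative infinity into every accumulator entry with global_n_idx > global_m_idx before the softmax. Minus infinity is the right mask value because exp(-inf) is exactly zero, so masked positions get zero attention weight while the kept weights still sum to one. A standalone sketch of that effect, not kernel code:

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    std::vector<float> row = {0.5f, 1.5f, 2.5f, 3.5f};
    const int m = 1; // row index: the causal mask keeps only columns n <= m
    for(int n = 0; n < static_cast<int>(row.size()); ++n)
        if(n > m)
            row[n] = -std::numeric_limits<float>::infinity();

    // Numerically stable softmax over the masked row.
    float row_max = row[0];
    for(float v : row) row_max = std::fmax(row_max, v);
    float sum = 0.f;
    for(float& v : row) { v = std::exp(v - row_max); sum += v; }
    for(float& v : row) v /= sum;

    for(float v : row) std::printf("%f ", v); // masked entries print 0.000000
    std::printf("\n");
}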