Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b1dd76f3
"example/vscode:/vscode.git/clone" did not exist on "75af5450827cfdc3e1fd10246fa4ce5f337a7ace"
Commit
b1dd76f3
authored
Jun 16, 2022
by
ltqin
Browse files
regular code
parent
e8a71150
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
117 additions
and
213 deletions
+117
-213
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
+4
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
..._operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
+33
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
+73
-197
include/ck/utility/synchronization.hpp
include/ck/utility/synchronization.hpp
+7
-0
No files found.
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
View file @
b1dd76f3
...
@@ -118,11 +118,11 @@ int main(int argc, char* argv[])
...
@@ -118,11 +118,11 @@ int main(int argc, char* argv[])
#else
#else
ck
::
index_t
M
=
16
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
32
;
ck
::
index_t
K
=
8
;
ck
::
index_t
StrideA
=
8
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
8
;
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideC
=
16
;
ck
::
index_t
StrideC
=
N
;
#endif
#endif
if
(
argc
==
4
)
if
(
argc
==
4
)
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
View file @
b1dd76f3
...
@@ -241,16 +241,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
...
@@ -241,16 +241,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
a_thread_copy_
.
SetSrcCoord
(
CalculateAThreadOriginDataIndex
());
a_thread_copy_
.
SetSrcCoord
(
CalculateAThreadOriginDataIndex
());
}
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
template
<
typename
ABlockBuffer
,
typename
AThreadBuffer
>
__device__
void
ReadAThreadData
(
const
ABlockBuffer
&
a_block_buf
,
AThreadBuffer
&
a_thread_buf
)
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
...
@@ -258,8 +251,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
...
@@ -258,8 +251,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
a_block_buf
,
a_block_buf
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
(
Number
<
m0
>
{}));
});
}
__host__
__device__
static
auto
AlloCAThreadBuff
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
ignore
=
i
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
,
a_thread_desc_
.
GetElementSpaceSize
(),
true
>
{};
},
Number
<
MRepeat
>
{});
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_thread_buf
,
const
BBlockBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
// auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
// a_thread_desc_.GetElementSpaceSize());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
// read B
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
...
@@ -267,8 +287,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
...
@@ -267,8 +287,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
constexpr
index_t
k0
=
k
/
KPack
;
constexpr
index_t
k0
=
k
/
KPack
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
a_thread_buf
[
Number
<
m0
>
{}][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
k0
,
n0
,
i
))
>
{}];
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
k0
,
n0
,
i
))
>
{}];
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
View file @
b1dd76f3
...
@@ -449,7 +449,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -449,7 +449,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
.
GetElementSpaceSize
(),
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
.
GetElementSpaceSize
(),
true
>
{};
true
>
{};
},
},
Number
<
8
>
{});
Number
<
BaseMultK0
>
{});
const
auto
wave_id
=
GetWaveIdx
();
const
auto
wave_id
=
GetWaveIdx
();
const
auto
wave_k_n_id
=
GetWaveKNIdx
(
wave_id
[
I2
]);
const
auto
wave_k_n_id
=
GetWaveKNIdx
(
wave_id
[
I2
]);
...
@@ -516,6 +516,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -516,6 +516,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
NXdlPerWave
,
NXdlPerWave
,
K1
>
{};
K1
>
{};
auto
a_thread_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
ignore
=
i
;
return
blockwise_gemm
.
AlloCAThreadBuff
();
},
Number
<
BaseMultK0
/
2
>
{});
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A
// LDS allocation for A
...
@@ -531,38 +538,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -531,38 +538,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
auto
read_b_first_half_data
=
[
&
]()
{
static_for
<
0
,
MultiK0
/
2
,
1
>
{}([
&
](
auto
ii
)
{
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
0
>
{}));
b_thread_buf
(
Number
<
ii
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
});
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
};
b_grid_buf
,
auto
read_b_last_half_data
=
[
&
]()
{
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
static_for
<
MultiK0
/
2
,
MultiK0
,
1
>
{}([
&
](
auto
ii
)
{
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
1
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
2
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
3
>
{}));
b_thread_buf
(
Number
<
ii
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
});
};
auto
read_a_lds_data
=
[
&
]()
{
static_for
<
0
,
MultiK0
/
2
,
1
>
{}([
&
](
auto
ii
)
{
blockwise_gemm
.
ReadAThreadData
(
a_block_buf
,
a_thread_buf
(
Number
<
ii
>
{}));
blockwise_gemm
.
MoveABlockSliceWindow
();
});
};
read_b_first_half_data
();
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
// a data write to lds
// a data write to lds
...
@@ -580,91 +584,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -580,91 +584,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
MultiK0
,
BaseMultK0
>
{}([
&
](
auto
)
{
static_for
<
0
,
MultiK0
,
BaseMultK0
>
{}([
&
](
auto
)
{
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
read_a_lds_data
();
b_grid_buf
,
read_b_last_half_data
();
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
4
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
5
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
6
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
7
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
s_nop
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
0
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
1
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
2
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
3
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
s_barrier
();
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
0
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
1
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
static_for
<
0
,
MultiK0
/
2
,
1
>
{}([
&
](
auto
ii
)
{
b_grid_buf
,
blockwise_gemm
.
Run
(
a_thread_buf
(
Number
<
ii
>
{}),
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_buf
(
Number
<
ii
>
{}),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
);
b_thread_buf
(
Number
<
2
>
{}));
});
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
3
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
s_nop
();
read_a_lds_data
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
4
>
{}),
c_thread_buf
);
read_b_first_half_data
();
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
5
>
{}),
c_thread_buf
);
s_barrier
();
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
6
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
7
>
{}),
c_thread_buf
);
static_for
<
MultiK0
/
2
,
MultiK0
,
1
>
{}([
&
](
auto
ii
)
{
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_thread_buf
(
Number
<
ii
-
4
>
{}),
b_thread_buf
(
Number
<
ii
>
{}),
c_thread_buf
);
});
});
});
block_sync_lds
();
block_sync_lds
();
...
@@ -683,94 +623,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -683,94 +623,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
blockwise_gemm
.
ResetABlockStartWindow
();
blockwise_gemm
.
ResetABlockStartWindow
();
static_for
<
0
,
MultiK0
,
BaseMultK0
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MultiK0
,
BaseMultK0
>
{}([
&
](
auto
i
)
{
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
read_a_lds_data
();
b_grid_buf
,
read_b_last_half_data
();
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
4
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
5
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
s_barrier
();
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
6
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
7
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
s_nop
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
0
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
1
>
{}),
c_thread_buf
);
static_for
<
0
,
MultiK0
/
2
,
1
>
{}([
&
](
auto
ii
)
{
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_thread_buf
(
Number
<
ii
>
{}),
b_thread_buf
(
Number
<
ii
>
{}),
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
2
>
{}),
c_thread_buf
);
});
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
3
>
{}),
c_thread_buf
);
read_a_lds_data
();
blockwise_gemm
.
MoveABlockSliceWindow
();
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
{
{
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
read_b_first_half_data
();
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
0
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
1
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
2
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_grid_buf
,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
(
Number
<
3
>
{}));
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
}
}
s_nop
();
s_barrier
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
4
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
5
>
{}),
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
6
>
{}),
c_thread_buf
);
static_for
<
MultiK0
/
2
,
MultiK0
,
1
>
{}([
&
](
auto
ii
)
{
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_thread_buf
(
Number
<
ii
-
4
>
{}),
b_thread_buf
(
Number
<
ii
>
{}),
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_buf
(
Number
<
7
>
{}),
c_thread_buf
);
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
);
}
);
});
});
}
}
}
}
...
...
include/ck/utility/synchronization.hpp
View file @
b1dd76f3
...
@@ -23,5 +23,12 @@ __device__ void s_nop()
...
@@ -23,5 +23,12 @@ __device__ void s_nop()
"
::
);
"
::
);
}
}
__device__
void
s_barrier
()
{
asm
volatile
(
"\
s_barrier \
"
::
);
}
}
// namespace ck
}
// namespace ck
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment