gaoqiong / composable_kernel · Commits

Commit 86834375, authored May 16, 2022 by ltqin

    finish blockwise gemm

Parent: 0adfd1a4
Showing 3 changed files with 18 additions and 29 deletions.
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp                                    +1  -1
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp   +13 -24
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp   +4  -4
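Taken together, the three files implement one idea: the B operand stops making a round trip through LDS and instead lives in per-thread registers (VGPRs), while A keeps its LDS staging. A minimal sketch of that idea, compilable with hipcc or nvcc; every name, shape, and index below is invented for illustration and none of it is CK code:

// Toy kernel: A is staged through LDS (shared memory) as usual, while each
// thread loads its B fragment straight from global memory into registers.
// One block of 64 threads computes a 64x1 column tile of C; K is fixed at 8.
__global__ void gemm_tile_b_in_registers(
    const float* a, const float* b, float* c, int lda, int ldb, int ldc)
{
    __shared__ float a_tile[64][8]; // A: cooperative staging through LDS
    float b_frag[8];                // B: registers only, no LDS round trip

    const int row = blockIdx.x * 64 + threadIdx.x;
    const int col = blockIdx.y;

    for(int k = 0; k < 8; ++k)
    {
        a_tile[threadIdx.x][k] = a[row * lda + k]; // global -> LDS
        b_frag[k]              = b[k * ldb + col]; // global -> register
    }
    __syncthreads(); // barrier protects only the LDS-staged operand

    float acc = 0.f;
    for(int k = 0; k < 8; ++k)
        acc += a_tile[threadIdx.x][k] * b_frag[k];

    c[row * ldc + col] = acc;
}

The payoff, as far as the file names and hunks suggest, is lower LDS capacity pressure and one less producer-consumer barrier on the B path.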
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp

@@ -52,7 +52,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
 //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
     < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
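The column-label row above the parameter list did not survive the page extraction, so the values are hard to read. Below is a hedged guide to what each group most likely is, assuming this instance follows the parameter order of the sibling DeviceGemmXdl instances in the same examples; treat every label as an inference, not a quote from the source:

// Assumed column guide for the DeviceGemmInstance parameters above:
//   F16, F16, F16, F32                      A, B, C data types and accumulator type
//   Row, Col, Row                           A, B, C layouts
//   PassThrough, PassThrough, PassThrough   A, B, C element-wise operations
//   GemmDefault                             GEMM specialization
//   128, 32, 64, 4, 8                       BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1
//   32, 32, 1, 1                            MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
//   S<4,32,1>, S<1,0,2>, S<1,0,2>, 2, 8, 8, true   A block-transfer cluster, orders,
//                                                  vector dim, scalars per vector, extra-M pad
//   S<4,32,1>, S<1,0,2>, S<1,0,2>, 2, 8, 8, true   same group for B
//   7, 1                                    C thread-transfer vector dim, scalars per vector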
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_b_register.hpp
@@ -250,13 +250,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
-                        const BBlockBuffer& b_block_buf,
+                        const BBlockBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            b_thread_desc_.GetElementSpaceSize());
+        // auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        //     b_thread_desc_.GetElementSpaceSize());

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             // read A
@@ -269,13 +269,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, I0),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(I0, I0, I0, I0),
-                                   b_thread_buf);
-
                 static_for<0, KPerThread, KPack>{}([&](auto k) {
                     vector_type<FloatAB, KPack> a_thread_vec;
                     vector_type<FloatAB, KPack> b_thread_vec;
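Almost every statement in these hunks is wrapped in static_for, CK's compile-time loop. As a reading aid, here is a toy stand-in that shows the shape of the idiom; this is an assumption about what static_for does, not CK's actual implementation:

#include <cstdio>
#include <utility>

// Toy model of static_for<Begin, End, Inc>{}([&](auto i){ ... }): call f with
// an integral_constant for every index in [Begin, End) stepping by Inc, so the
// loop is fully unrolled and each index is usable in template arguments.
template <int Begin, int End, int Inc, typename F>
constexpr void toy_static_for(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        toy_static_for<Begin + Inc, End, Inc>(std::forward<F>(f));
    }
}

int main()
{
    constexpr int KPerThread = 8, KPack = 4;
    // mirrors the static_for<0, KPerThread, KPack> loop in the hunk above
    toy_static_for<0, KPerThread, KPack>([](auto k) {
        std::printf("K chunk starting at %d\n", k.value); // prints 0, then 4
    });
}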
@@ -283,8 +276,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                     static_for<0, KPack, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                             [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
-                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
+                            [Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(0, k / KPack, 0, n0, 0, 0, i))>{}];
                     });

                     using mfma_input_type =
@@ -309,7 +303,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     // B[N0, N1, N2, KPerThread]
     static constexpr auto b_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(I1,
+                                                       Number<KPerThread>{}, // KPerThread
+                                                       I1,                   // NBlockId
+                                                       Number<NRepeat>{},    // repeat
+                                                       I1,                   // waves
+                                                       I1,                   // NPerXdlops
+                                                       Number<KPack>{}));

     // C[M, N, NumRegXdlops]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
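b_thread_desc_ is a packed descriptor, so the CalculateOffset call in the previous hunk, with index (0, k / KPack, 0, n0, 0, 0, i) over lengths (1, KPerThread, 1, NRepeat, 1, 1, KPack), is a plain row-major linearization. A small runnable model of that arithmetic; the helper and the concrete lengths are illustrative assumptions, not CK code:

#include <array>
#include <cstddef>
#include <cstdio>

// Row-major ("packed") linearization: the stride of each dimension is the
// product of the lengths to its right, computed here Horner-style.
constexpr std::size_t packed_offset(const std::array<std::size_t, 7>& len,
                                    const std::array<std::size_t, 7>& idx)
{
    std::size_t off = 0;
    for(std::size_t d = 0; d < 7; ++d)
        off = off * len[d] + idx[d];
    return off;
}

int main()
{
    constexpr std::size_t KPerThread = 8, NRepeat = 2, KPack = 4; // example values
    const std::array<std::size_t, 7> len{1, KPerThread, 1, NRepeat, 1, 1, KPack};

    // index pattern from the diff: (0, k / KPack, 0, n0, 0, 0, i), here with i = 0;
    // each (k-chunk, n0) pair lands on its own KPack-aligned base offset
    for(std::size_t k = 0; k < KPerThread; k += KPack)
        for(std::size_t n0 = 0; n0 < NRepeat; ++n0)
            std::printf("k=%zu n0=%zu -> offset %zu\n", k, n0,
                        packed_offset(len, {0, k / KPack, 0, n0, 0, 0, 0}));
}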
@@ -325,18 +325,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                                                   A_K1,
                                                   A_K1>;

-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                         FloatAB,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerThread>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
-
     AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 };

 } // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
+#ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP
+#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V2R3_HPP

 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
@@ -668,7 +668,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
                               make_tuple(I0, I0, I0, I0, I0, I0, I0),
                               b_thread_buf);

-            // blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
+            blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);

             block_sync_lds();

             // move windows
@@ -687,7 +687,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
             {
                 block_sync_lds();

-                // blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
             }
         }
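Reading the two hunks together: the first re-enables the blockwise GEMM inside the main K loop, right before the source windows advance; the second fixes the tail iteration to consume the register-resident b_thread_buf rather than the old b_block_buf. The remaining block_sync_lds() calls are for A, which still stages through LDS. For what the kernel-side B copy amounts to, here is a toy, fully host-side model; all types and names are invented stand-ins, not CK code:

#include <array>
#include <cstdio>

// Stand-in for make_static_buffer<AddressSpaceEnum::Vgpr, ...>: a small
// fixed-size buffer that models a thread's register-resident B fragment.
template <std::size_t N>
struct ThreadBuffer
{
    std::array<float, N> data{};
};

// Stand-in for the threadwise copy whose tail is visible in the first hunk:
// pull this thread's B slice straight from "global" memory into registers,
// which is the staging step the blockwise class no longer performs itself.
template <std::size_t N>
void thread_copy_b(const float* b_global, std::size_t thread_offset,
                   ThreadBuffer<N>& b_thread_buf)
{
    for(std::size_t i = 0; i < N; ++i)
        b_thread_buf.data[i] = b_global[thread_offset + i];
}

int main()
{
    float b_global[64];
    for(int i = 0; i < 64; ++i)
        b_global[i] = static_cast<float>(i);

    ThreadBuffer<8> b_thread_buf;                    // KPerThread-sized fragment
    thread_copy_b(b_global, /*thread_offset=*/16, b_thread_buf);

    std::printf("b fragment starts at %g\n", b_thread_buf.data[0]); // 16
}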