gaoqiong / composable_kernel / Commits / 437c996a

Commit 437c996a authored Apr 21, 2021 by Chao Liu

use StaticBuffer for thread matrix A/B in blockwise GEMM

parent 36de63ff
Showing 4 changed files with 49 additions and 138 deletions
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                          +2   -5
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +6   -7
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp                         +23  -126
composable_kernel/include/utility/amd_inline_asm.hpp                                      +18  -0
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
View file @ 437c996a

@@ -545,11 +545,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
             make_tuple(Number<NPerThread>{}, Number<1>{}));

-        FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
-        FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
-
-        auto a_thread_buf = make_dynamic_buffer(p_a_thread);
-        auto b_thread_buf = make_dynamic_buffer(p_b_thread);
+        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
                                                                       FloatB,
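This hunk drops the raw per-thread arrays p_a_thread / p_b_thread that were wrapped with make_dynamic_buffer and instead builds the thread-local A/B fragments as static buffers, so each element is addressed through a compile-time offset and can stay in registers. Below is a minimal host-side sketch of that idea, assuming a hypothetical static_buffer_sketch type; it is not the library's actual StaticBuffer or make_static_buffer implementation.

    // Minimal sketch (not the library's actual StaticBuffer): a statically sized,
    // compile-time-indexed buffer that the compiler can keep entirely in registers.
    #include <array>
    #include <cstddef>

    template <typename T, std::size_t N>
    struct static_buffer_sketch
    {
        std::array<T, N> data_{};

        // Compile-time index: the offset is a template parameter, so each access
        // resolves to a distinct element instead of a runtime pointer offset.
        template <std::size_t I>
        T& at()
        {
            static_assert(I < N, "index out of range");
            return data_[I];
        }
    };

    // Hypothetical usage mirroring the commit's pattern
    //   auto a_thread_buf = make_static_buffer<FloatA>(desc.GetElementSpaceSize());
    // here approximated with a size fixed at compile time.
    int main()
    {
        static_buffer_sketch<float, 4> a_thread_buf;
        a_thread_buf.at<0>() = 1.0f; // index is checked at compile time
        return 0;
    }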
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @ 437c996a

@@ -1379,9 +1379,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");

-#if 0 // debug
         static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
-#endif

         static_assert(is_known_at_compile_time<
                           remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&

@@ -1437,13 +1435,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             container_reorder_given_new2old(access_lengths, dim_access_order);

         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-            // position in slice window
-#if 0
+#if 0 // debug
             // TODO: unable to compile
+            // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                 src_scalar_per_access;
 #else
+            // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
                 src_scalar_per_access;
 #endif

@@ -1470,13 +1469,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                 src_desc, src_data_coord);

 #if 0
-            // TODO: this is slooooooooow!
+            // TODO: this is slooooooooow due to VGPR over-allocation
             src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
                 is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset() /
                                                                        SrcScalarPerVector]
                              : src_vector_t{0};
 #else
-            // this has normal performance but it's hacky
+            // TODO: this is workaround. this has normal performance but it's hacky
             src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
                 is_src_valid
                     ? *reinterpret_cast<const src_vector_t*>(&(reinterpret_cast<const SrcData*>(
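In the last hunk the kept #else branch reinterprets the source scalar pointer as a vector pointer and substitutes a zero vector when the coordinate falls outside the slice window; the commit only rewords the comments to note that the disabled path is slow because of VGPR over-allocation and that the kept path is a workaround. A rough host-side sketch of the same guarded vector-load pattern is shown below; the guarded_vector_load name is hypothetical and std::memcpy stands in for the diff's reinterpret_cast, so types and names are illustrative rather than the library's.

    // Load a small vector of scalars starting at `scalar_offset` when `is_valid`,
    // otherwise return a zero-filled vector (out-of-window reads are masked this way).
    #include <array>
    #include <cstddef>
    #include <cstring>

    using src_vector_sketch = std::array<float, 4>; // stands in for src_vector_t

    inline src_vector_sketch
    guarded_vector_load(const float* src, std::size_t scalar_offset, bool is_valid)
    {
        src_vector_sketch v{}; // zero-filled fallback
        if(is_valid)
        {
            // bitwise copy of 4 contiguous scalars; avoids the aliasing concerns
            // of the reinterpret_cast used in the device code
            std::memcpy(v.data(), src + scalar_offset, sizeof(v));
        }
        return v;
    }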
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
View file @ 437c996a

@@ -191,60 +191,12 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
               typename BOriginIdx,
               typename CBuffer,
               typename COriginIdx>
-    __device__ static void Run_source(const ABuffer& a_buf,
+    __device__ static void Run(const ABuffer& a_buf,
                                       AOriginIdx,
                                       const BBuffer& b_buf,
                                       BOriginIdx,
                                       CBuffer& c_buf,
                                       COriginIdx)
-    {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr auto M = CDesc{}.GetLength(I0);
-        constexpr auto N = CDesc{}.GetLength(I1);
-        constexpr auto K = ADesc{}.GetLength(I0);
-
-        constexpr auto a_origin_idx = AOriginIdx{};
-        constexpr auto b_origin_idx = BOriginIdx{};
-        constexpr auto c_origin_idx = COriginIdx{};
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, M, 1>{}([&](auto m) {
-                static_for<0, N, 1>{}([&](auto n) {
-                    constexpr auto a_offset =
-                        ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
-                    constexpr auto b_offset =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
-                    constexpr auto c_offset =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
-
-                    c_buf.template AsType<FloatC>()(c_offset) +=
-                        inner_product_with_conversion<FloatC>{}(
-                            a_buf.template AsType<FloatA>()[a_offset],
-                            b_buf.template AsType<FloatB>()[b_offset]);
-                });
-            });
-        });
-    }
-
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-    template <typename ABuffer,
-              typename AOriginIdx,
-              typename BBuffer,
-              typename BOriginIdx,
-              typename CBuffer,
-              typename COriginIdx>
-    __device__ static void Run_amd_asm(const ABuffer& a_buf,
-                                       AOriginIdx,
-                                       const BBuffer& b_buf,
-                                       BOriginIdx,
-                                       CBuffer& c_buf,
-                                       COriginIdx)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                           CDesc::IsKnownAtCompileTime(),

@@ -258,8 +210,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};

         constexpr auto M = CDesc{}.GetLength(I0);
         constexpr auto N = CDesc{}.GetLength(I1);

@@ -269,83 +219,30 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
         constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

-        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
-
         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
-                constexpr auto a_offset =
-                    ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
-
-                if constexpr(N == 2)
-                {
-                    constexpr auto b_offset_0 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-                    constexpr auto b_offset_1 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-                    constexpr auto c_offset_0 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-                    constexpr auto c_offset_1 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-
-                    amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
-                                                   b_buf.template AsType<FloatB>()[b_offset_0],
-                                                   b_buf.template AsType<FloatB>()[b_offset_1],
-                                                   c_buf.template AsType<FloatC>()(c_offset_0),
-                                                   c_buf.template AsType<FloatC>()(c_offset_1));
-                }
-                else if constexpr(N == 4)
-                {
-                    constexpr auto b_offset_0 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-                    constexpr auto b_offset_1 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-                    constexpr auto b_offset_2 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
-                    constexpr auto b_offset_3 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
-                    constexpr auto c_offset_0 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-                    constexpr auto c_offset_1 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-                    constexpr auto c_offset_2 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
-                    constexpr auto c_offset_3 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
-
-                    amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
-                                                   b_buf.template AsType<FloatB>()[b_offset_0],
-                                                   b_buf.template AsType<FloatB>()[b_offset_1],
-                                                   b_buf.template AsType<FloatB>()[b_offset_2],
-                                                   b_buf.template AsType<FloatB>()[b_offset_3],
-                                                   c_buf.template AsType<FloatC>()(c_offset_0),
-                                                   c_buf.template AsType<FloatC>()(c_offset_1),
-                                                   c_buf.template AsType<FloatC>()(c_offset_2),
-                                                   c_buf.template AsType<FloatC>()(c_offset_3));
-                }
-            });
-        });
-    }
-#endif
-
-    template <typename ABuffer,
-              typename AOriginIdx,
-              typename BBuffer,
-              typename BOriginIdx,
-              typename CBuffer,
-              typename COriginIdx>
-    __device__ static void Run(const ABuffer& a_buf,
-                               AOriginIdx,
-                               const BBuffer& b_buf,
-                               BOriginIdx,
-                               CBuffer& c_buf,
-                               COriginIdx)
-    {
+                static_for<0, N, 1>{}([&](auto n) {
+                    constexpr index_t a_offset =
+                        ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+                    constexpr index_t b_offset =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
+                    constexpr index_t c_offset =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
+
 #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-        Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+                    amd_assembly_inner_product(a_buf.template AsType<FloatA>()[Number<a_offset>{}],
+                                               b_buf.template AsType<FloatB>()[Number<b_offset>{}],
+                                               c_buf.template AsType<FloatC>()(Number<c_offset>{}));
 #else
-        Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+                    c_buf.template AsType<FloatC>()(Number<c_offset>{}) +=
+                        inner_product_with_conversion<FloatC>{}(
+                            a_buf.template AsType<FloatA>()[Number<a_offset>{}],
+                            b_buf.template AsType<FloatB>()[Number<b_offset>{}]);
 #endif
+                });
+            });
+        });
     }
 };
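The net effect of this file's changes: the generic Run_source, the N == 2 / N == 4 inline-asm outer-product paths, and the dispatching Run collapse into a single Run with a K x M x N static_for nest that issues one multiply-accumulate per element, through either amd_assembly_inner_product or inner_product_with_conversion. The plain C++ reference sketch below shows that loop structure only; the function name and array parameters are illustrative, not library code.

    // Reference sketch of the unified per-thread GEMM:
    // C[m][n] += sum over k of A[k][m] * B[k][n], one scalar FMA per (k, m, n),
    // mirroring the K x M x N static_for nest in the new Run().
    #include <cstddef>

    template <std::size_t K, std::size_t M, std::size_t N>
    void threadwise_gemm_km_kn_mn_reference(const float (&a)[K][M],
                                            const float (&b)[K][N],
                                            float (&c)[M][N])
    {
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t n = 0; n < N; ++n)
                    c[m][n] += a[k][m] * b[k][n]; // device code does this via v_fmac_f32
                                                  // or inner_product_with_conversion
    }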
composable_kernel/include/utility/amd_inline_asm.hpp
View file @ 437c996a

@@ -5,6 +5,24 @@
 namespace ck {

+// c += inner_product(a, b)
+__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
+{
+#if CK_USE_AMD_V_FMAC_F32
+    asm volatile("\n \
+    v_fmac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#else
+    asm volatile("\n \
+    v_mac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#endif
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
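The new amd_assembly_inner_product helper computes c += a * b with a single v_fmac_f32 instruction (or v_mac_f32 when CK_USE_AMD_V_FMAC_F32 is off). Below is a hypothetical device-side usage sketch that is not part of the commit; the kernel name, include path, and launch context are assumptions made for illustration.

    // Illustrative HIP kernel: each call to ck::amd_assembly_inner_product
    // lowers to one fused multiply-accumulate, i.e. acc += a[i] * b[i].
    #include <hip/hip_runtime.h>
    #include "amd_inline_asm.hpp" // assumed include path for the helper above

    __global__ void fma_demo(const float* a, const float* b, float* c, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
        {
            float acc = c[i];
            ck::amd_assembly_inner_product(a[i], b[i], acc); // acc += a[i] * b[i]
            c[i] = acc;
        }
    }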