gaoqiong / composable_kernel

Commit 5cfe67b1
Authored Aug 24, 2023 by aska-0096
Parent: cc0ffeb7

NewBlkGEMM

Showing 5 changed files with 285 additions and 109 deletions (+285 -109).
Changed files:
- example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp (+28 -10)
- include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+142 -84)
- include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+62 -0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (+36 -14)
- include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+17 -1)
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp

@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
@@ -83,30 +84,30 @@ using DeviceOpInstance =
         1,
         128,
         64,
-        128,
+        64,
         64,
-        8,
+        4,
         16,
         16,
-        1,
+        2,
         4,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        4,
-        4,
-        true,
+        8,
+        8,
+        false,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        4,
-        4,
-        true,
+        8,
+        8,
+        false,
         1,
         1,
-        S<1, 64, 1, 2>,
+        S<1, 32, 1, 4>,
         8>;
 
 int main(int argc, char* argv[])
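The hunk above only retunes per-block tiling and vectorization parameters of the DeviceOpInstance. As a rough way to sanity-check such a tile configuration before compiling a full instance, a small stand-alone sketch follows. The parameter names, the wave size of 32, and the relation MPerBlock = MRepeat * MWaves * MPerWmma are conventional composable_kernel assumptions on my part, not values read off the template above, and the numbers in main() are illustrative only.

#include <cstdio>

// Hypothetical compile-time sanity check for a WMMA tile configuration.
// Assumes the usual CK relations; it does not mirror the template's real
// parameter order and is only a sketch.
template <int BlockSize, int MPerBlock, int NPerBlock, int MPerWmma, int NPerWmma,
          int MRepeat, int NRepeat, int WaveSize = 32>
struct TileConfigCheck
{
    static constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma);

    static_assert(MPerBlock % (MRepeat * MPerWmma) == 0,
                  "MPerBlock must split into MRepeat x MWaves x MPerWmma");
    static_assert(NPerBlock % (NRepeat * NPerWmma) == 0,
                  "NPerBlock must split into NRepeat x NWaves x NPerWmma");
    static_assert(MWaves * NWaves * WaveSize == BlockSize,
                  "the wave grid must exactly fill the workgroup");
};

int main()
{
    // Example numbers only; they are not asserted to match the instance above.
    TileConfigCheck<128, 64, 64, 16, 16, 2, 2> check{};
    (void)check;
    std::printf("tile configuration is consistent\n");
    return 0;
}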
@@ -208,12 +209,29 @@ int main(int argc, char* argv[])
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
         break;
+    case 2:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<DDataType>{1.f, 1.f}(d_m_n);
+        break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
     }
 
+#if 0
+    for(int im = 0; im < M; im++)
+    {
+        for(int ik = 0; ik < K; ik++)
+        {
+            if(ik % 8 == 0) printf("|");
+            printf("%4x ", *(reinterpret_cast<uint16_t*>(&(a_m_k(im, ik)))));
+        }
+        printf("\n");
+    }
+#endif
+
     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
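The new case 2 fills A, B, and D with the constant 1 via FillUniformDistributionIntegerValue, which makes the result easy to check by hand: every dot product reduces to the reduction length K. A minimal host-side sketch of that check is below; it assumes the example's bilinear epilogue has the form E = alpha * (A x B) + beta * D, and the names and sizes in it are illustrative rather than taken from the example.

#include <cstdio>

// Host-side model of the "all ones" verification case, assuming
// E[m][n] = alpha * sum_k A[m][k] * B[k][n] + beta * D[m][n].
// With A = B = D = 1 everywhere, every output element equals alpha * K + beta.
int main()
{
    const int K       = 64;   // illustrative reduction length
    const float alpha = 1.0f;
    const float beta  = 1.0f;

    float acc = 0.0f;
    for(int k = 0; k < K; ++k)
        acc += 1.0f * 1.0f;   // A[m][k] * B[k][n] with all-ones inputs

    const float e = alpha * acc + beta * 1.0f;
    std::printf("expected E[m][n] = %.1f (alpha*K + beta)\n", e);
    return 0;
}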
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

This diff is collapsed and is not shown here.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -448,6 +448,68 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
     __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
 };
+
+template <typename InputDataType, index_t RegPackNumber>
+struct InterRowPermuter
+{
+};
+
+template <>
+struct InterRowPermuter<ck::half_t, 2>
+{
+    using InputArray  = vector_type<ck::half_t, 2>;
+    using OutputArray = vector_type<ck::half_t, 2>;
+
+    __device__ static OutputArray convert(InputArray const& Input)
+    {
+        OutputArray Output;
+
+        uint32_t* output_half_2     = reinterpret_cast<uint32_t*>(&Output);
+        uint32_t const input_half_2 = reinterpret_cast<uint32_t const&>(Input);
+
+        output_half_2[0] = __builtin_amdgcn_permlanex16(
+            output_half_2[0], input_half_2, 0x76543210, 0xfedcba98, 1, 0);
+#if 0
+        if(get_thread_local_1d_id() == 0)
+        {
+            printf("After permlanex, input: %04x, output: %04x\n", input_half_2, output_half_2);
+        }
+#endif
+        return Output;
+    }
+
+    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
+};
+
+template <index_t N>
+struct InterRowPermuter<ck::half_t, N>
+{
+    static constexpr int VEC_WIDTH = 2;
+    static_assert(!(N % VEC_WIDTH), "N must be multiple of 2.");
+
+    using InputArray  = vector_type<ck::half_t, N>;
+    using OutputArray = vector_type<ck::half_t, N>;
+
+    __device__ static OutputArray convert(InputArray const& Input)
+    {
+        InterRowPermuter<ck::half_t, 2> converter;
+        OutputArray Output;
+
+        using Vec_InputArray  = vector_type<ck::half_t, 2>;
+        using Vec_OutputArray = vector_type<ck::half_t, 2>;
+
+        Vec_OutputArray* output_half_2_ptr     = reinterpret_cast<Vec_OutputArray*>(&Output);
+        Vec_InputArray const* input_half_2_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
+
+        static_for<0, N / VEC_WIDTH, 1>{}(
+            [&](auto i) { output_half_2_ptr[i] = converter(input_half_2_ptr[i]); });
+
+        return Output;
+    }
+
+    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
+};
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
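The added InterRowPermuter walks a packed fp16 vector two elements at a time and exchanges each 32-bit pair across rows with __builtin_amdgcn_permlanex16 using identity selectors (0x76543210, 0xfedcba98). The sketch below is only a host-side model of the data movement I believe this performs, namely swapping lane i of the lower 16 lanes with lane i of the upper 16 lanes of a wave32; it does not call the builtin, and the function name is made up for illustration.

#include <array>
#include <cstdint>
#include <cstdio>

// Host-side model only: lane i of lanes [0..15] exchanges its packed value with
// lane i of lanes [16..31]. This is the behaviour assumed here for permlanex16
// with identity selectors; it is an illustration, not the builtin itself.
std::array<uint32_t, 32> inter_row_exchange(const std::array<uint32_t, 32>& in)
{
    std::array<uint32_t, 32> out{};
    for(int lane = 0; lane < 16; ++lane)
    {
        out[lane]      = in[lane + 16]; // lower half reads from the upper half
        out[lane + 16] = in[lane];      // upper half reads from the lower half
    }
    return out;
}

int main()
{
    std::array<uint32_t, 32> lanes{};
    for(int i = 0; i < 32; ++i)
        lanes[i] = static_cast<uint32_t>(i); // pretend each lane holds two packed halves

    const auto swapped = inter_row_exchange(lanes);
    std::printf("lane 0 now holds the value from lane %u\n", swapped[0]);
    return 0;
}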
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -498,15 +498,23 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
-            constexpr auto A_K0   = ABlockDesc_{}.GetLength(I0);
-            constexpr auto A_K1   = ABlockDesc_{}.GetLength(I2);
-            constexpr auto A_KRow = I1;
+            // Debug this part
+            constexpr auto A_KRow     = 2;
+            constexpr auto A_K0PerRow = ABlockDesc_{}.GetLength(I0) / A_KRow;
+            constexpr auto A_K1       = ABlockDesc_{}.GetLength(I2);
+            // return make_naive_tensor_descriptor_packed(make_tuple(Number<A_K0PerRow>{},
+            //                                                       Number<MRepeat>{},
+            //                                                       Number<MWaves>{},
+            //                                                       Number<A_KRow>{},
+            //                                                       Number<MPerWmma>{},
+            //                                                       Number<A_K1>{}));
 
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0PerRow>{}, Number<A_KRow>{})),
                            make_unmerge_transform(
                                make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                 make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
         }
@@ -537,17 +545,31 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
-            constexpr auto B_K0   = BBlockDesc_{}.GetLength(I0);
-            constexpr auto B_K1   = BBlockDesc_{}.GetLength(I2);
-            constexpr auto B_KRow = I1;
+#if 1
+            constexpr auto B_KRow     = 2;
+            constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
+            constexpr auto B_K1       = BBlockDesc_{}.GetLength(I2);
 
             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0PerRow>{}, Number<B_KRow>{})),
                            make_unmerge_transform(
                                make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                 make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
+#endif
+#if 0
+            constexpr auto B_KRow     = 2;
+            constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
+            constexpr auto B_K1       = BBlockDesc_{}.GetLength(I2);
+            return make_naive_tensor_descriptor_packed(make_tuple(Number<B_K0PerRow>{},
+                                                                  Number<NRepeat>{},
+                                                                  Number<NWaves>{},
+                                                                  Number<B_KRow>{},
+                                                                  Number<NPerWmma>{},
+                                                                  Number<B_K1>{}));
+#endif
         }
         else
         {
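Both descriptor hunks reshape the (K0, M or N, K1) block descriptor into a six-dimensional (K0PerRow, Repeat, Waves, KRow, PerWmma, K1) view by unmerging K0 into (K0PerRow, KRow) and the M or N dimension into (Repeat, Waves, PerWmma). The sketch below spells out the corresponding index arithmetic on the host; the split order (last factor fastest) and the concrete sizes are assumptions used for illustration, not code taken from the kernel.

#include <cstdio>

// Illustrative index math for the unmerge used above (host-side model).
// A flat (k0, m, k1) coordinate is re-expressed in the 6-d view
// (k0_per_row, m_repeat, m_wave, k_row, m_per_wmma, k1), assuming
// k0 = k0_per_row * KRow + k_row and
// m  = (m_repeat * MWaves + m_wave) * MPerWmma + m_per_wmma.
struct Split6D
{
    int k0_per_row, m_repeat, m_wave, k_row, m_per_wmma, k1;
};

Split6D split(int k0, int m, int k1, int KRow, int MWaves, int MPerWmma)
{
    Split6D out{};
    out.k0_per_row = k0 / KRow;
    out.k_row      = k0 % KRow;
    out.m_per_wmma = m % MPerWmma;
    out.m_wave     = (m / MPerWmma) % MWaves;
    out.m_repeat   = m / (MPerWmma * MWaves);
    out.k1         = k1;
    return out;
}

int main()
{
    // Example: KRow = 2, MWaves = 2, MPerWmma = 16 (illustrative values only).
    const Split6D s = split(/*k0=*/5, /*m=*/37, /*k1=*/3, /*KRow=*/2, /*MWaves=*/2, /*MPerWmma=*/16);
    std::printf("k0_per_row=%d k_row=%d m_repeat=%d m_wave=%d m_per_wmma=%d k1=%d\n",
                s.k0_per_row, s.k_row, s.m_repeat, s.m_wave, s.m_per_wmma, s.k1);
    return 0;
}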
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1136,7 +1136,17 @@ struct ThreadwiseTensorSliceTransfer_v4
             auto src_data_coord = src_ref_coord_;
             move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
+#if 0
+            printf("Tid: %03d, Inele_Offset: %d, Coord: (%d, %d, %d, %d, %d, %d)\n",
+                   get_thread_local_1d_id(),
+                   src_data_coord.GetOffset(),
+                   src_data_coord.GetIndex().At(Number<0>{}),
+                   src_data_coord.GetIndex().At(Number<1>{}),
+                   src_data_coord.GetIndex().At(Number<2>{}),
+                   src_data_coord.GetIndex().At(Number<3>{}),
+                   src_data_coord.GetIndex().At(Number<4>{}),
+                   src_data_coord.GetIndex().At(Number<5>{}));
+#endif
             vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
 
             using src_vector_t = typename decltype(src_tmp_vector)::type;

@@ -1178,6 +1188,12 @@ struct ThreadwiseTensorSliceTransfer_v4
                     dst_buf(Number<dst_offset>{}) =
                         dst_tmp_vector.template AsType<DstData>()[i];
                 });
+#if 0
+                printf("Tid: %03d, Inele_Offset: %d\n",
+                       get_thread_local_1d_id(),
+                       dst_desc.CalculateOffset(dst_origin_idx + data_to_origin_disp_idx));
+#endif
             });
         }