gaoqiong / composable_kernel

Commit 092f54b4, authored Oct 29, 2023 by Jing Zhang
parent 5530440b

add transpose

Showing 2 changed files with 240 additions and 10 deletions:
- example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp (+2, -2)
- include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp (+238, -8)
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp

@@ -142,8 +142,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl

The changed lines fall inside the DeviceOpInstance tuning-parameter list; the fragment visible in the hunk reads:

    S<1, 0, 2>, S<1, 0, 2>, S<1, 0, 2>, S<1, 0, 2>,
    1, 1, 1, 2, 1, 8, 1, 1, 1, 1, 1, 1,
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp

@@ -8,9 +8,42 @@

 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/tensor_space_filling_curve.hpp"
 #include "ck/utility/is_detected.hpp"
+#include "ck/tensor/static_tensor.hpp"

 namespace ck {
+namespace detail {
+
+// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
+// doesn't have a constructor.
+template <index_t SrcVectorDim,
+          index_t SrcScalarPerVector,
+          index_t DstVectorDim,
+          index_t DstScalarPerVector>
+struct lambda_scalar_per_access_for_src_and_dst
+{
+    __host__ __device__ constexpr auto operator()(index_t i) const
+    {
+        if(i == SrcVectorDim && i == DstVectorDim)
+        {
+            return math::lcm(SrcScalarPerVector, DstScalarPerVector);
+        }
+        else if(i == SrcVectorDim)
+        {
+            return SrcScalarPerVector;
+        }
+        else if(i == DstVectorDim)
+        {
+            return DstScalarPerVector;
+        }
+        else
+        {
+            return 1;
+        }
+    }
+};
+
+} // namespace detail

 // Thread-level multi-source, multi-destination tensor slice data movement
 // Assume:
 // 1. All sources and destinations are DynamicBuffer
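For intuition, here is a minimal standalone model of the new functor (plain C++17, outside CK's type machinery; the dimension assignments below are made-up examples). The dimension that both source and destination vectorize over must step by the lcm of the two vector widths so each access covers whole vectors on both sides; any other vectorized dimension steps by its own width, and the rest step one scalar at a time.

    #include <cstdio>
    #include <numeric> // std::lcm

    // Scalar model of detail::lambda_scalar_per_access_for_src_and_dst.
    constexpr int scalar_per_access(int i,
                                    int src_vector_dim, int src_scalar_per_vector,
                                    int dst_vector_dim, int dst_scalar_per_vector)
    {
        if(i == src_vector_dim && i == dst_vector_dim)
            return std::lcm(src_scalar_per_vector, dst_scalar_per_vector);
        else if(i == src_vector_dim)
            return src_scalar_per_vector;
        else if(i == dst_vector_dim)
            return dst_scalar_per_vector;
        else
            return 1;
    }

    int main()
    {
        // hypothetical 3-D slice: src vectorized along dim 2 with 8 scalars,
        // dst vectorized along dim 1 with 2 scalars
        for(int i = 0; i < 3; ++i)
            std::printf("dim %d steps by %d\n", i, scalar_per_access(i, 2, 8, 1, 2));
        // prints: dim 0 steps by 1, dim 1 steps by 2, dim 2 steps by 8
    }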
@@ -70,16 +103,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2

     static constexpr auto src_scalar_per_access = generate_sequence(
         detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

-    using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
-                                                   SrcDimAccessOrder,
-                                                   remove_cv_t<decltype(src_scalar_per_access)>>;
-
     static constexpr auto dst_scalar_per_access = generate_sequence(
         detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

+    using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                   SrcDimAccessOrder,
+                                                   remove_cv_t<decltype(src_scalar_per_access)>,
+                                                   false>;
+
     using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DstDimAccessOrder,
-                                                   remove_cv_t<decltype(dst_scalar_per_access)>>;
+                                                   remove_cv_t<decltype(dst_scalar_per_access)>,
+                                                   false>;

     __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(const SrcDescs& src_descs,
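Both space-filling curves now receive an explicit fourth template argument, false. Assuming this is the snake-winding flag that SpaceFillingCurve (declared in tensor_space_filling_curve.hpp) otherwise defaults, the change pins traversal to plain row-linear order instead of reversing every other row. A small standalone sketch of the two orders over made-up 2x4 access lengths:

    #include <cstdio>

    int main()
    {
        constexpr int rows = 2, cols = 4;
        for(bool snake : {false, true})
        {
            std::printf(snake ? "snaked: " : "linear: ");
            for(int r = 0; r < rows; ++r)
                for(int c = 0; c < cols; ++c)
                {
                    // a snaked curve walks odd rows right-to-left so that
                    // consecutive accesses stay spatially adjacent
                    const int cc = (snake && r % 2 == 1) ? cols - 1 - c : c;
                    std::printf("(%d,%d) ", r, cc);
                }
            std::printf("\n");
        }
    }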
@@ -241,17 +276,114 @@ struct ThreadwiseTensorSliceTransfer_v7r2

         });
     }

+    __device__ void TransposeFromElmToDst()
+    {
+        using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
+
+        using SrcThreadScratch =
+            StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+                                            DstData,
+                                            SrcScalarPerVector,
+                                            decltype(GetSrcThreadScratchDescriptor()),
+                                            true>;
+
+        using DstThreadScratch =
+            StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+                                            DstData,
+                                            DstScalarPerVector,
+                                            decltype(GetDstThreadScratchDescriptor()),
+                                            true>;
+
+        SrcThreadScratch elm_thread_scratch_;
+        DstThreadScratch dst_thread_scratch_;
+
+        elm_thread_scratch_.data_ =
+            bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
+
+        if constexpr(SrcVectorDim != DstVectorDim &&
+                     ((is_same<half_t, remove_cvref_t<DstData>>::value &&
+                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
+                      (is_same<int8_t, remove_cvref_t<DstData>>::value &&
+                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
+        {
+            // each transpose does
+            // DstScalarPerVector # of src vectors in src_thread_scratch_
+            // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
+            constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
+            constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
+
+            // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
+            // TODO: make this logic generic for all scenarios
+            static_assert(SrcVectorDim != DstVectorDim, "wrong");
+
+            constexpr auto src_scalar_step_in_vector = generate_sequence(
+                detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
+
+            constexpr auto dst_scalar_step_in_vector = generate_sequence(
+                detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
+
+            constexpr auto scalar_per_access = generate_sequence(
+                detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
+                                                                 SrcScalarPerVector,
+                                                                 DstVectorDim,
+                                                                 DstScalarPerVector>{},
+                Number<nDim>{});
+
+            constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
+
+            static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
+                constexpr auto data_idx = access_idx * scalar_per_access;
+
+                constexpr auto data_idx_seq =
+                    generate_sequence_v2([&](auto i) { return Number<data_idx[i]>{}; },
+                                         Number<nDim>{});
+
+                using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
+                using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
+
+                // get DstScalarPerVector # of read-only references to src vectors from
+                // src_thread_scratch_
+                const auto src_vector_refs = generate_tie(
+                    [&](auto i) -> const src_vector_t& {
+                        // i increment corresponds to movement in DstVectorDim
+                        return elm_thread_scratch_.GetVectorTypeReference(
+                            data_idx_seq + i * dst_scalar_step_in_vector);
+                    },
+                    Number<num_src_vector>{});
+
+                // get SrcScalarPerVector # of references to dst vectors from
+                // dst_thread_scratch_
+                auto dst_vector_refs = generate_tie(
+                    [&](auto i) -> dst_vector_t& {
+                        // i increment corresponds to movement in SrcVectorDim
+                        return dst_thread_scratch_.GetVectorTypeReference(
+                            data_idx_seq + i * src_scalar_step_in_vector);
+                    },
+                    Number<num_dst_vector>{});
+
+                // do data transpose
+                transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
+                    src_vector_refs, dst_vector_refs);
+            });
+        }
+        else
+        {
+            static_ford<SliceLengths>{}(
+                [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
+        }
+
+        dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
+    }
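Concrete numbers make the vector bookkeeping above easier to follow. In a hypothetical tile with SrcScalarPerVector = 4 and DstScalarPerVector = 2, one call to transpose_vectors consumes DstScalarPerVector = 2 source vectors and produces SrcScalarPerVector = 4 destination vectors. The standalone sketch below performs the same index mapping with plain scalar loads instead of packed register moves:

    #include <cstdio>

    int main()
    {
        constexpr int SrcScalarPerVector = 4; // made-up src vector width
        constexpr int DstScalarPerVector = 2; // made-up dst vector width

        // DstScalarPerVector source vectors, each SrcScalarPerVector scalars wide
        short src[DstScalarPerVector][SrcScalarPerVector] = {{0, 1, 2, 3},
                                                             {4, 5, 6, 7}};
        // SrcScalarPerVector destination vectors, each DstScalarPerVector wide
        short dst[SrcScalarPerVector][DstScalarPerVector];

        // the transpose: element i of src vector j becomes element j of dst vector i
        for(int i = 0; i < SrcScalarPerVector; ++i)
            for(int j = 0; j < DstScalarPerVector; ++j)
                dst[i][j] = src[j][i];

        for(int i = 0; i < SrcScalarPerVector; ++i)
            std::printf("dst vector %d = {%d, %d}\n", i, dst[i][0], dst[i][1]);
        // dst vectors: {0,4} {1,5} {2,6} {3,7}
    }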
     // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
     // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
     template <typename DstBuffers,
-              enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
+              enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
     __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
     {
-        dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(elm_vectors_tuple_);
+        TransposeFromElmToDst();

         // loop over space-filling curve
         static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
-            auto dst_vectors = dst_vectors_tuple_[iAccess];
+            auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];

             // copy data from buf_vectors into dst_bufs
             static_for<0, nDst, 1>{}([&](auto i) {
@@ -336,6 +468,104 @@ struct ThreadwiseTensorSliceTransfer_v7r2

         }
     }

+    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
+    {
+        // constexpr auto src_scalar_per_access = generate_sequence(
+        //     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
+        //     Number<nDim>{});
+
+        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto src_access_lengths_and_vector_length = container_push_back(
+            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
+
+        // 1st stage of transforms
+        constexpr auto desc0 =
+            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
+
+        // 2nd stage of transforms
+        constexpr auto transforms = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return make_merge_transform_v3_division_mod(
+                        make_tuple(src_access_lengths_and_vector_length[i],
+                                   src_access_lengths_and_vector_length[Number<nDim>{}]));
+                }
+                else
+                {
+                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto low_dim_idss = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Sequence<i.value, nDim>{};
+                }
+                else
+                {
+                    return Sequence<i.value>{};
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto up_dim_idss =
+            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
+
+        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
+    }
+    __device__ static constexpr auto GetDstThreadScratchDescriptor()
+    {
+        // 1st stage of transforms
+        // constexpr auto dst_scalar_per_access = generate_sequence(
+        //     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
+        //     Number<nDim>{});
+
+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
+        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
+            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
+
+        constexpr auto desc0 =
+            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
+
+        // 2nd stage of transforms
+        constexpr auto transforms = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == DstVectorDim)
+                {
+                    return make_merge_transform_v3_division_mod(
+                        make_tuple(dst_access_lengths_and_vector_length[i],
+                                   dst_access_lengths_and_vector_length[Number<nDim>{}]));
+                }
+                else
+                {
+                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto low_dim_idss = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == DstVectorDim)
+                {
+                    return Sequence<i.value, nDim>{};
+                }
+                else
+                {
+                    return Sequence<i.value>{};
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto up_dim_idss =
+            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
+
+        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
+    }
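A worked example of the two-stage construction both scratch descriptors share, under assumed sizes (not taken from the commit): with SliceLengths = (4, 8), SrcVectorDim = 1 and SrcScalarPerVector = 4, the access lengths are (4, 2), desc0 is a packed (4, 2, 4) descriptor, and the merge transform folds the trailing vector length back into dimension 1, so a dim-1 index i1 decomposes into access index i1 / 4 and lane i1 % 4. A scalar model of the resulting offset arithmetic:

    #include <cassert>

    int main()
    {
        // assumed example sizes
        constexpr int D0 = 4, D1 = 8; // SliceLengths
        constexpr int V  = 4;         // SrcScalarPerVector along dim 1
        constexpr int A1 = D1 / V;    // access length along dim 1

        // desc0: packed (D0, A1, V) descriptor -> linear scratch offset
        auto desc0_offset = [](int i0, int a1, int v) { return (i0 * A1 + a1) * V + v; };

        // merged (D0, D1) view: the merge transform is exactly this division/modulo
        auto merged_offset = [&](int i0, int i1) { return desc0_offset(i0, i1 / V, i1 % V); };

        // both views address the same packed scratch element
        for(int i0 = 0; i0 < D0; ++i0)
            for(int i1 = 0; i1 < D1; ++i1)
                assert(merged_offset(i0, i1) == i0 * D1 + i1);

        return 0;
    }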
     // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
     template <index_t ISrc>
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,