Project: yangql / composable_kernel-1

Commit 33d1e0e2, authored Jun 17, 2019 by Chao Liu
Parent commit: b1cb48a0

Commit message:

    refactoring for miopen

Changes: 33. Showing 13 changed files on this page, with 68 additions and 462 deletions (+68 / -462).
Changed files on this page (diff stats per file):

    composable_kernel/include/tensor_operation/threadwise_gemm.hpp                        +4   -23
    composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp   +27  -26
    composable_kernel/include/tensor_operation/threadwise_tensor_slice_copy.hpp           +3   -4
    composable_kernel/include/utility/Array.hpp                                           +1   -1
    composable_kernel/include/utility/Sequence.hpp                                        +6   -81
    composable_kernel/include/utility/amd_inline_asm.hpp                                  +0   -168
    composable_kernel/include/utility/config_amd.hpp.in                                   +3   -0
    composable_kernel/include/utility/config_nvidia.hpp.in                                +3   -0
    composable_kernel/include/utility/functional.hpp                                      +2   -4
    composable_kernel/include/utility/integral_constant.hpp                               +10  -7
    composable_kernel/include/utility/utility.hpp                                         +7   -21
    composable_kernel/include/utility/vector_type.hpp                                     +0   -125
    driver/include/tensor.hpp                                                             +2   -2
composable_kernel/include/tensor_operation/threadwise_gemm.hpp (+4 / -23)

@@ -71,24 +71,7 @@ __device__ void threadwise_gemm(MatrixA,
                                 integral_constant<bool, TransC>,
                                 FloatC* __restrict__ p_c_thread)
 {
-    static_if<TransA && (!TransB) && (!TransC)>{}([&](auto fwd) {
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            printf("p_a_thread: %f %f %f %f\n",
-                   p_a_thread[0],
-                   p_a_thread[1],
-                   p_a_thread[2],
-                   p_a_thread[3]);
-            printf("p_b_thread: %f %f %f %f\n",
-                   p_b_thread[0],
-                   p_b_thread[1],
-                   p_b_thread[2],
-                   p_b_thread[3]);
-        }
-#endif
+    if(TransA && (!TransB) && (!TransC))
+    {
         constexpr auto a_mtx = MatrixA{};
         constexpr auto b_mtx = MatrixB{};
         constexpr auto c_mtx = MatrixC{};

@@ -111,12 +94,10 @@ __device__ void threadwise_gemm(MatrixA,
                 }
             }
         }
-    }).Else([&](auto fwd) {
-        // not implemented
-        static_assert(fwd(false), "wrong! support for this config is not implemented");
-    });
+    }
+    else
+    {
+        // not implemented
+        assert(false);
+    }
 }
 } // namespace ck
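The change above replaces the compile-time static_if<...>{}(...).Else(...) dispatch with a plain runtime if/else. In the old Else branch, static_assert(fwd(false), ...) wrapped false in the dependent fwd argument so the assert would only fire if that branch were actually instantiated; the new code keeps every configuration compiling and falls back to assert(false) at run time. A minimal host-side sketch of the new control flow (illustrative names, not the CK kernel itself):

    #include <cassert>
    #include <cstdio>

    // Hedged sketch of the new failure style: all template configurations
    // compile, and unsupported ones abort at run time instead of failing the
    // build inside an uninstantiated branch.
    template <bool TransA, bool TransB, bool TransC>
    void threadwise_gemm_sketch()
    {
        if(TransA && (!TransB) && (!TransC))
        {
            std::printf("run the supported A^T * B path\n");
        }
        else
        {
            // not implemented: reached only for unsupported configurations
            assert(false);
        }
    }

    int main()
    {
        threadwise_gemm_sketch<true, false, false>(); // supported path
        // threadwise_gemm_sketch<false, false, false>(); // would assert at run time
    }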
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp (+27 / -26)

@@ -5,6 +5,10 @@
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
 
+#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
+#endif
+
 namespace ck {
 
 template <class Float,

@@ -32,21 +36,18 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
 
-#if 0
-    // doesn't compile, because merged-tensor reordering is not implemented
-    constexpr auto src_strides_in_access_order =
-        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    constexpr auto dst_strides_in_access_order =
-        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    // check src/dst stride on the lowest access dimension
-    static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
-                      (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
-                  "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
-                  "vectorized read/write");
-#endif
+    // TODO: do more sanity-check here, something like:
+    // TODO: implement tensor desc ops for merged-tensor
+    // constexpr auto src_strides_in_access_order =
+    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+    //
+    // constexpr auto dst_strides_in_access_order =
+    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+    //
+    // // check src/dst stride on the lowest access dimension
+    // static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
+    //                   (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
+    //               "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
+    //               "vectorized read/write");
 
     constexpr auto slice_lengths_in_access_order =
         SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});

@@ -64,13 +65,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
 
-#if 1
-    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        auto data_multi_id_in_access_order = access_multi_id;
-        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
-        const auto data_multi_id =
-            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
+#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
+    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
+        constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
+        constexpr auto data_multi_id_in_access_order =
+            access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
+        constexpr auto data_multi_id = reorder_array_given_old2new(
+            sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
 
         const index_t src_index =
             SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);

@@ -82,14 +85,12 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
             *reinterpret_cast<const vector_t*>(&p_src[src_index]);
     });
 #else
-    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
-        constexpr auto data_multi_id_in_access_order =
-            access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
-        constexpr auto data_multi_id = reorder_array_given_old2new(
-            sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
+    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
+        auto data_multi_id_in_access_order = access_multi_id;
+        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
+        const auto data_multi_id =
+            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
 
         const index_t src_index =
             SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
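The net effect in this file: the fully compile-time static_ford path is now gated behind CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 (default 0), and the runtime ford path becomes the default. Both traverse the same access space; a hedged host-side sketch of what a ford-style traversal does (illustrative only, not CK's device implementation, which hands out Array/Sequence multi-indices):

    #include <array>
    #include <cstdio>

    // Visit every multi-index of an N-dimensional index space (lengths assumed
    // positive) and hand it to a callback, last dimension fastest.
    template <std::size_t N, class F>
    void ford(const std::array<int, N>& lengths, F f)
    {
        std::array<int, N> id{}; // current multi-index, starts at all zeros
        while(true)
        {
            f(id);
            // odometer-style increment, last dimension first
            std::size_t d = N;
            while(d > 0)
            {
                --d;
                if(++id[d] < lengths[d]) break;
                id[d] = 0;
                if(d == 0) return; // wrapped the slowest dimension: done
            }
        }
    }

    int main()
    {
        ford<2>(std::array<int, 2>{2, 3}, [](const std::array<int, 2>& id) {
            std::printf("(%d, %d)\n", id[0], id[1]); // six multi-indices
        });
    }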
composable_kernel/include/tensor_operation/threadwise_tensor_slice_copy.hpp (+3 / -4)

@@ -56,7 +56,7 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc,
     static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
         static_for<0, nRead, 1>{}([&](auto IRead) {
-            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
+            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead * DataPerRead>{});
 
             const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);

@@ -177,8 +177,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
         // pack data
         static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
-            const auto dst_multi_id =
-                ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());
+            const auto dst_multi_id =
+                ids.PushBack(IWrite * DstDataPerWrite + IDstData);
 
             const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

@@ -189,7 +188,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
         });
 
         // write data
-        const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);
+        const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite);
 
         const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
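The .Get() calls can be dropped because, elsewhere in this commit, Number<N> becomes an alias for std::integral_constant, which converts implicitly to its value in ordinary integer arithmetic. A small sketch of that mechanism (int stands in for ck::index_t):

    #include <type_traits>
    #include <cstdio>

    template <int N>
    using Number = std::integral_constant<int, N>;

    int main()
    {
        constexpr Number<3> IRead{};
        constexpr int DataPerRead = 4;

        // implicit conversion via std::integral_constant::operator value_type()
        constexpr int offset = IRead * DataPerRead;
        static_assert(offset == 12, "compile-time arithmetic, no .Get() needed");

        std::printf("%d\n", offset);
    }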
composable_kernel/include/utility/Array.hpp (+1 / -1)

@@ -98,7 +98,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
 {
     static_assert(NSize == sizeof...(IRs), "NSize not consistent");
 
-    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
+    static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
 
     return Array<TData, NSize>{old_array[IRs]...};
 }
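For context, reorder_array_given_new2old builds a new array whose i-th element is old_array[map[i]], i.e. the map says, for each new position, which old position it came from. A hedged host-side sketch with illustrative names (not CK's actual Array/Sequence machinery):

    #include <array>
    #include <cstdio>

    template <class T, std::size_t N>
    std::array<T, N> reorder_given_new2old(const std::array<T, N>& old_array,
                                           const std::array<std::size_t, N>& new2old)
    {
        std::array<T, N> result{};
        for(std::size_t i = 0; i < N; ++i)
            result[i] = old_array[new2old[i]]; // new position i comes from old position new2old[i]
        return result;
    }

    int main()
    {
        // reorder NCHW lengths {N, C, H, W} into NHWC order
        std::array<int, 4> nchw{1, 8, 32, 32};
        std::array<std::size_t, 4> new2old{0, 2, 3, 1};
        auto nhwc = reorder_given_new2old(nchw, new2old);
        std::printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]); // 1 32 32 8
    }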
composable_kernel/include/utility/Sequence.hpp (+6 / -81)

@@ -55,22 +55,6 @@ struct Sequence
         return Sequence<Type::Get(Number<IRs>{})...>{};
     }
 
-#if 0 // require sequence_sort, which is not implemented yet
-    template <class MapOld2New>
-    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
-    {
-        static_assert(sizeof...(Is) == MapOld2New::GetSize(),
-                      "wrong! reorder map should have the same size as Sequence to be rerodered");
-
-        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
-
-        constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
-
-        return ReorderGivenNew2Old(map_new2old);
-    }
-#endif
-
     __host__ __device__ static constexpr auto Reverse();
 
     __host__ __device__ static constexpr index_t Front()

@@ -263,74 +247,15 @@ struct sequence_reverse<Sequence<I0, I1>>
     using SeqType = Sequence<I1, I0>;
 };
 
-#if 0 // not fully implemented
-template <class KeySeq0, class ValSeq0, class KeySeq1, class ValSeq1>
-struct sequence_sort_merge_impl;
-
-template <index_t Key0,
-          index_t... Keys0,
-          index_t Val0,
-          index_t... Vals0,
-          index_t Key1,
-          index_t... Keys1,
-          index_t Val0,
-          index_t... Vals1>
-struct sequence_sort_merge_impl<Sequence<Key0, Keys0...>,
-                                Sequence<Val0, Vals0...>,
-                                Sequence<Key1, Keys1...>,
-                                Sequence<Val1, Vals1...>>
-{
-};
-
-template <class>
-struct sequence_sort;
-
-template <index_t... Is>
-struct sequence_sort<Sequence<Is...>>
-{
-    using OriginalSeqType = Sequence<Is...>;
-    using SortedSeqType = xxxxx;
-    using MapSorted2OriginalType = xxx;
-};
-
-template <class Seq, class IsValidSeqMap>
-struct sequence_map_inverse_impl;
-
-// impl for valid map, no impl for invalid map
-template <index_t... Is>
-struct sequence_map_inverse_impl<Sequence<Is...>, true>
-{
-    using SeqMapType = sequence_sort<Sequence<Is...>>::MapSorted2OriginalType;
-};
-
-template <class>
-struct sequence_map_inverse;
-
-template <class Is...>
-struct sequence_map_inverse<Sequence<Is...>>
-{
-    // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is))
-    static constexpr bool is_valid_sequence_map =
-        is_same<typename sequence_sort<Sequence<Is...>>::SortedSeqType,
-                typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value;
-
-    // make compiler fails, if is_valid_map != true
-    using SeqMapType =
-        typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_map>::SeqMapType;
-};
-#endif
-
 template <class Seq>
 struct is_valid_sequence_map
 {
     static constexpr bool value =
-#if 0 // sequence_sort is not implemented yet
-        is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
-                typename sequence_sort<Seq>::SortedSeqType>::value;
-#else
         true;
-#endif
+
+    // TODO: add proper check for is_valid, something like:
+    // static constexpr bool value =
+    //     is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
+    //             typename sequence_sort<Seq>::SortedSeqType>{};
 };
 
 template <index_t... Xs, index_t... Ys>
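Both the deleted #if 0 block and the remaining TODO point at the same missing piece: a compile-time check that a reorder map is a permutation of 0..N-1, which is what sequence_sort would make possible on Sequence<>. A hedged sketch of that condition as a constexpr function over a plain array (illustrative only, not the Sequence-based check CK intends to add):

    #include <array>

    // A reorder map is valid iff every index in 0..N-1 appears exactly once.
    template <std::size_t N>
    constexpr bool is_valid_map(const std::array<std::size_t, N>& map)
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            bool found = false;
            for(std::size_t j = 0; j < N; ++j)
                if(map[j] == i) found = true;
            if(!found) return false;
        }
        return true;
    }

    static_assert(is_valid_map<3>({{2, 0, 1}}), "a permutation of 0..2 is a valid map");
    static_assert(!is_valid_map<3>({{0, 0, 2}}), "a repeated index is not a valid map");

    int main() {}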
composable_kernel/include/utility/amd_inline_asm.hpp (+0 / -168)

@@ -3,91 +3,8 @@
 #include "vector_type.hpp"
 
-#define NO_VM_WAIT 0
-#define NO_LGKM_WAIT 0
-#define NO_DS_READ 0
-#define NO_DS_WRITE 0
-#define NO_GLB_READ 0
-
 namespace ck {
 
-// cast a pointer of LDS to its address
-extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
-
-__device__ void vmcnt(index_t cnt)
-{
-#if !NO_VM_WAIT
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}
-
-__device__ void lgkmcnt(index_t cnt)
-{
-#if !NO_LGKM_WAIT
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 3)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(3) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(4) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}
-
 __device__ void outerProduct1x4(const float* a, const float* b, float* c)
 {
     asm volatile("\n \

@@ -112,21 +29,7 @@ __device__ void outerProduct1x4(const float& a,
                                 const vector_type<float, 4>::MemoryType& b,
                                 vector_type<float, 4>::MemoryType& c)
 {
-#if 0
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        :"v"(c.x),"v"(c.y),"v"(c.z),"v"(c.w), \
-         "v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
-        );
-#else
     outerProduct1x4(&a, (float*)&b, (float*)&c);
-#endif
 }
 
 __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,

@@ -136,57 +39,10 @@ __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
                                 vector_type<float, 4>::MemoryType& c2,
                                 vector_type<float, 4>::MemoryType& c3)
 {
-#if 0
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        :"v"(c0.x),"v"(c0.y),"v"(c0.z),"v"(c0.w), \
-         "v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
-        );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        :"v"(c1.x),"v"(c1.y),"v"(c1.z),"v"(c1.w), \
-         "v"(a.y),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
-        );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        :"v"(c2.x),"v"(c2.y),"v"(c2.z),"v"(c2.w), \
-         "v"(a.z),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
-        );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        :"v"(c3.x),"v"(c3.y),"v"(c3.z),"v"(c3.w), \
-         "v"(a.w),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
-        );
-#else
     outerProduct1x4(a.x, b, c0);
     outerProduct1x4(a.y, b, c1);
     outerProduct1x4(a.z, b, c2);
     outerProduct1x4(a.w, b, c3);
-#endif
 }
 
 __device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,

@@ -201,7 +57,6 @@ __device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
 __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
 {
-#if !NO_DS_READ
     if(offset == 0)
     {
         asm volatile("\n \

@@ -722,33 +577,11 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
                     : "=v"(r)
                     : "v"(__to_local(lds)));
     }
-#endif
 }
 
-__device__ void global_load(vector_type<float, 4>::MemoryType& r,
-                            const vector_type<float, 4>::MemoryType* ptr,
-                            index_t offset = 0)
-{
-#if !NO_GLB_READ
-    if(offset == 0)
-    {
-        asm volatile("\n \
-        global_load_dwordx4 %0, %1, off \n \
-        "
-                     : "=v"(r)
-                     : "v"(ptr));
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}
-
 __device__ void
 ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
 {
-#if !NO_DS_WRITE
     if(offset == 0)
     {
         asm volatile("\n \

@@ -761,7 +594,6 @@ ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t off
     {
         assert(false);
     }
-#endif
 }
 
 } // namespace ck
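This file only loses code: the NO_* gating macros, the vmcnt/lgkmcnt wait-count helpers, the global_load wrapper, and the commented-out v_mac_f32 blocks inside the outer-product helpers. For reference, the arithmetic the surviving outerProduct1x4/outerProduct4x4 wrappers perform is a rank-1 update; a plain host-side restatement of that math (semantics read off the deleted asm as an assumption, not CK code):

    #include <cstdio>

    // c[j] += a * b[j]: one row of a 4x4 outer product accumulated into c.
    void outerProduct1x4(const float* a, const float* b, float* c)
    {
        for(int j = 0; j < 4; ++j)
            c[j] += a[0] * b[j];
    }

    // c is a 4x4 accumulator, row-major: row i gets a[i] * b[0..3].
    void outerProduct4x4(const float* a, const float* b, float* c)
    {
        for(int i = 0; i < 4; ++i)
            outerProduct1x4(&a[i], b, &c[4 * i]);
    }

    int main()
    {
        float a[4] = {1, 2, 3, 4};
        float b[4] = {1, 10, 100, 1000};
        float c[16] = {};
        outerProduct4x4(a, b, c);
        std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 1 10 100 1000
        std::printf("%g %g %g %g\n", c[4], c[5], c[6], c[7]); // 2 20 200 2000
    }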
composable_kernel/include/utility/config_amd.hpp.in (+3 / -0)

@@ -7,6 +7,9 @@
 #include "hip/hip_fp16.h"
 
 #define CK_USE_AMD_INLINE_ASM 1
+
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 
 namespace ck {
 
 // For some reason, HIP compiler need this definition to generate optimal load and store
composable_kernel/include/utility/config_nvidia.hpp.in (+3 / -0)

@@ -9,6 +9,9 @@
 #include "helper_cuda.h"
 
 #define CK_USE_AMD_INLINE_ASM 0
+
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
 
 namespace ck {
 
 // For some reason, CUDA need this definition, otherwise
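Both platform config templates now predefine the experimental macros, and the slice-copy header adds an #ifndef fallback so it still builds when the config header does not define them. A hedged sketch of that gating pattern with made-up names (the macro and function below are illustrative, not CK identifiers):

    #include <cstdio>

    // Default the experimental switch to off if the config header did not set it.
    #ifndef EXAMPLE_USE_STATIC_LOOP
    #define EXAMPLE_USE_STATIC_LOOP 0
    #endif

    int sum4(const int* v)
    {
    #if EXAMPLE_USE_STATIC_LOOP
        // the "more compile-time static" variant: fully unrolled
        return v[0] + v[1] + v[2] + v[3];
    #else
        // the default runtime-loop variant
        int s = 0;
        for(int i = 0; i < 4; ++i) s += v[i];
        return s;
    #endif
    }

    int main()
    {
        int v[4] = {1, 2, 3, 4};
        std::printf("%d\n", sum4(v)); // 10 either way; only the codegen differs
    }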
composable_kernel/include/utility/functional.hpp (+2 / -4)

@@ -24,10 +24,8 @@ struct swallow
 };
 
 // Emulate if constexpr
-template <bool Predicate>
-struct static_if
-{
-};
+template <bool>
+struct static_if;
 
 template <>
 struct static_if<true>
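The primary template is now only declared; the true/false specializations carry the implementation, so any instantiation outside those two states is a hard incomplete-type error instead of silently matching an empty struct. A hedged host-side sketch of how such a static_if is used, including why the branch lambda takes a fwd argument (simplified stand-in, not CK's exact definition):

    #include <cstdio>

    // Identity functor handed to the branch lambda so its body stays dependent
    // and is only type-checked when that branch is actually instantiated.
    struct forwarder
    {
        template <class T>
        T operator()(T x) const { return x; }
    };

    template <bool>
    struct static_if; // declared only; unsupported instantiations fail to compile

    template <>
    struct static_if<true>
    {
        template <class F> auto operator()(F f) const { f(forwarder{}); return *this; }
        template <class F> void Else(F) const {} // not taken
    };

    template <>
    struct static_if<false>
    {
        template <class F> auto operator()(F) const { return *this; } // not taken
        template <class F> void Else(F f) const { f(forwarder{}); }
    };

    struct HasFoo { void Foo() const { std::printf("Foo\n"); } };
    struct NoFoo  {};

    template <bool UseFoo, class T>
    void run(const T& t)
    {
        static_if<UseFoo>{}([&](auto fwd) {
            // only instantiated when UseFoo is true, so t.Foo() is never
            // type-checked for NoFoo
            fwd(t).Foo();
        }).Else([&](auto) { std::printf("no Foo\n"); });
    }

    int main()
    {
        run<true>(HasFoo{});
        run<false>(NoFoo{});
    }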
composable_kernel/include/utility/integral_constant.hpp (+10 / -7)

 #ifndef CK_INTEGRAL_CONSTANT_HPP
 #define CK_INTEGRAL_CONSTANT_HPP
 
+#include <type_traits>
+
 namespace ck {
 
-template <class T, T N>
-struct integral_constant
-{
-    static const T value = N;
-
-    __host__ __device__ constexpr T Get() const { return value; }
-};
+template <class T, T v>
+using integral_constant = std::integral_constant<T, v>;
 
 template <class T, T X, T Y>
 __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)

@@ -17,6 +14,12 @@ __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_c
     return integral_constant<T, X + Y>{};
 }
 
+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
+{
+    return integral_constant<T, X * Y>{};
+}
+
 template <index_t N>
 using Number = integral_constant<index_t, N>;
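With the alias in place, Number<N> is a std::integral_constant, and the ck operator+ / operator* overloads keep constant arithmetic in the type system. A self-contained sketch of the idea (int stands in for index_t, and the operators here mirror the ones in the hunk above):

    #include <type_traits>

    template <class T, T v>
    using integral_constant = std::integral_constant<T, v>;

    template <int N>
    using Number = integral_constant<int, N>;

    template <class T, T X, T Y>
    constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X + Y>{};
    }

    template <class T, T X, T Y>
    constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X * Y>{};
    }

    // the result of each operation is again a typed constant
    constexpr auto n = Number<2>{} + Number<3>{} * Number<4>{};
    static_assert(n.value == 14, "2 + 3 * 4 evaluated on the types");
    static_assert(std::is_same<decltype(n), const Number<14>>::value, "result is a Number<14>");

    int main() {}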
composable_kernel/include/utility/utility.hpp (+7 / -21)

 #ifndef CK_UTILITY_HPP
 #define CK_UTILITY_HPP
 
+#include <type_traits>
 #include "config.hpp"
 
 namespace ck {

@@ -9,23 +10,8 @@ __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
 __device__ index_t get_block_1d_id() { return blockIdx.x; }
 
-template <class T1, class T2>
-struct is_same
-{
-    static constexpr bool value = false;
-};
-
-template <class T>
-struct is_same<T, T>
-{
-    static constexpr bool value = true;
-};
-
 template <class X, class Y>
-__host__ __device__ constexpr bool is_same_type(X, Y)
-{
-    return is_same<X, Y>::value;
-}
+using is_same = std::is_same<X, Y>;
 
 namespace math {

@@ -58,7 +44,7 @@ struct integer_divide_ceiler
 {
     __host__ __device__ constexpr T operator()(T a, T b) const
     {
-        static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+        static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
 
         return (a + b - 1) / b;
     }

@@ -67,7 +53,7 @@ struct integer_divide_ceiler
 template <class T>
 __host__ __device__ constexpr T integer_divide_ceil(T a, T b)
 {
-    static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
 
     return (a + b - 1) / b;
 }

@@ -85,7 +71,7 @@ __host__ __device__ constexpr T max(T x, Ts... xs)
     auto y = max(xs...);
 
-    static_assert(is_same<decltype(y), T>::value, "not the same type");
+    static_assert(is_same<decltype(y), T>{}, "not the same type");
 
     return x > y ? x : y;
 }

@@ -103,12 +89,12 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     auto y = min(xs...);
 
-    static_assert(is_same<decltype(y), T>::value, "not the same type");
+    static_assert(is_same<decltype(y), T>{}, "not the same type");
 
     return x < y ? x : y;
 }
 
-// this is wrong
+// this is WRONG
 // TODO: implement least common multiple properly, instead of calling max()
 template <class T, class... Ts>
 __host__ __device__ constexpr T lcm(T x, Ts... xs)
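is_same is now an alias for std::is_same, and because std::is_same derives from std::integral_constant it carries a constexpr conversion to bool; that is why the static_asserts can say is_same<T, U>{} instead of is_same<T, U>::value. A small sketch (int and long stand in for index_t and int):

    #include <type_traits>

    template <class X, class Y>
    using is_same = std::is_same<X, Y>;

    template <class T>
    constexpr T integer_divide_ceil(T a, T b)
    {
        // the trait object itself converts to bool inside static_assert
        static_assert(is_same<T, int>{} || is_same<T, long>{}, "wrong type");
        return (a + b - 1) / b;
    }

    static_assert(integer_divide_ceil(10, 4) == 3, "ceil(10/4) == 3");

    int main() {}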
composable_kernel/include/utility/vector_type.hpp (+0 / -125)

@@ -64,131 +64,6 @@ struct vector_type<float, 4>
     }
 };
 
-#if 0
-template <>
-struct vector_type<half, 1>
-{
-    using MemoryType = half;
-
-    __host__ __device__ static MemoryType Pack(half s) { return s; }
-};
-
-template <>
-struct vector_type<half, 2>
-{
-    using MemoryType = half2;
-
-    __host__ __device__ static MemoryType Pack(half s0, half s1)
-    {
-        union
-        {
-            MemoryType vector;
-            half scalar[2];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<half, 4>
-{
-    using MemoryType = float2;
-};
-
-template <>
-struct vector_type<half, 8>
-{
-    using MemoryType = float4;
-};
-
-template <>
-struct vector_type<char, 1>
-{
-    using MemoryType = char;
-
-    __host__ __device__ static MemoryType Pack(char s) { return s; }
-};
-
-template <>
-struct vector_type<char, 2>
-{
-    using MemoryType = int16_t;
-
-    __host__ __device__ static MemoryType Pack(char s0, char s1)
-    {
-        union
-        {
-            MemoryType vector;
-            char scalar[2];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<char, 4>
-{
-    using MemoryType = int32_t;
-
-    __host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
-    {
-        union
-        {
-            MemoryType vector;
-            char scalar[4];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        data.scalar[2] = s2;
-        data.scalar[3] = s3;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<char, 8>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<int32_t, 2>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<char2, 2>
-{
-    using MemoryType = char4;
-};
-
-template <>
-struct vector_type<char2, 4>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<char4, 1>
-{
-    using MemoryType = int;
-};
-
-template <>
-struct vector_type<char4, 2>
-{
-    using MemoryType = int64_t;
-};
-#endif
-
 } // namespace ck
 #endif
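The deleted block was unused half/char specializations of vector_type. The idea they implement: map a scalar type plus a lane count onto a wider "memory type" and pack scalars into it through a union, so loads and stores can be done one vector at a time. A hedged host-side sketch with stand-in types (not CK's definitions; the byte order of the packed value follows the target's endianness, and the union read mirrors CK's own idiom):

    #include <cstdint>
    #include <cstdio>

    template <class T, int N>
    struct vector_type;

    template <>
    struct vector_type<std::int8_t, 4>
    {
        using MemoryType = std::int32_t;

        static MemoryType Pack(std::int8_t s0, std::int8_t s1, std::int8_t s2, std::int8_t s3)
        {
            union
            {
                MemoryType vector;
                std::int8_t scalar[4];
            } data;
            data.scalar[0] = s0;
            data.scalar[1] = s1;
            data.scalar[2] = s2;
            data.scalar[3] = s3;
            return data.vector;
        }
    };

    int main()
    {
        auto v = vector_type<std::int8_t, 4>::Pack(1, 2, 3, 4);
        std::printf("0x%08x\n", static_cast<unsigned>(v)); // 0x04030201 on little-endian
    }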
driver/include/tensor.hpp (+2 / -2)

@@ -46,7 +46,7 @@ auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
 template <class F, class T>
 auto call_f_unpack_args(F f, T args)
 {
-    constexpr std::size_t N = std::tuple_size<T>::value;
+    constexpr std::size_t N = std::tuple_size<T>{};
     return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
 }

@@ -60,7 +60,7 @@ auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
 template <class F, class T>
 auto construct_f_unpack_args(F, T args)
 {
-    constexpr std::size_t N = std::tuple_size<T>::value;
+    constexpr std::size_t N = std::tuple_size<T>{};
    return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
 }
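The only change here is ::value versus {} on std::tuple_size, which works because std::tuple_size derives from std::integral_constant. For context, a hedged sketch of the call_f_unpack_args pattern this header uses to expand a tuple into a plain function call:

    #include <cstdio>
    #include <tuple>
    #include <utility>

    template <class F, class T, std::size_t... Is>
    auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
    {
        return f(std::get<Is>(args)...); // expand tuple elements as arguments
    }

    template <class F, class T>
    auto call_f_unpack_args(F f, T args)
    {
        constexpr std::size_t N = std::tuple_size<T>{};
        return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
    }

    int add3(int a, int b, int c) { return a + b + c; }

    int main()
    {
        auto args = std::make_tuple(1, 2, 3);
        std::printf("%d\n", call_f_unpack_args(add3, args)); // 6
    }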