yangql / composable_kernel-1 · Commits · 31b40352

Unverified commit 31b40352, authored Aug 18, 2021 by Chao Liu; committed by GitHub on Aug 18, 2021.

Merge pull request #16 from ROCmSoftwarePlatform/develop

    Merge develop into master

Parents: 5781adf5, b62bf8c3

Changes: 145 files in total; this page shows 20 changed files with 932 additions and 432 deletions (+932 −432).
composable_kernel/include/tensor_operation/xdlops_gemm.hpp                                        +10  −10
composable_kernel/include/utility/amd_address_space.hpp                                           +44  −0
composable_kernel/include/utility/amd_buffer_addressing.hpp                                       +68  −41
composable_kernel/include/utility/amd_dlop.hpp                                                    +0   −188
composable_kernel/include/utility/amd_inline_asm.hpp                                              +21  −18
composable_kernel/include/utility/c_style_pointer_cast.hpp                                        +22  −0
composable_kernel/include/utility/common_header.hpp                                               +9   −8
composable_kernel/include/utility/config.hpp                                                      +17  −25
composable_kernel/include/utility/data_type_enum.hpp                                              +2   −3
composable_kernel/include/utility/data_type_enum_helper.hpp                                       +2   −2
composable_kernel/include/utility/dynamic_buffer.hpp                                              +86  −48
composable_kernel/include/utility/enable_if.hpp                                                   +13  −0
composable_kernel/include/utility/inner_product.hpp                                               +207 −0
composable_kernel/include/utility/math.hpp                                                        +3   −12
composable_kernel/include/utility/print.hpp                                                       +0   −48
composable_kernel/include/utility/sequence.hpp                                                    +0   −2
composable_kernel/include/utility/static_buffer.hpp                                               +41  −5
composable_kernel/include/utility/tuple.hpp                                                       +14  −15
composable_kernel/include/utility/type.hpp                                                        +3   −7
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp  +370 −0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp

@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
           class FloatC>
 __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
 {
-    const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-    const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+    const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+    const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

     return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
         p_a, p_b, reg_c);

@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
           class FloatC>
 __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
 {
-    const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-    const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+    const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+    const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

     return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
 }

@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
           class FloatC>
 __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
 {
-    const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-    const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+    const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+    const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

     return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
 }

@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
           class FloatC>
 __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
 {
-    const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-    const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+    const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+    const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

     return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
 }

@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
           class FloatC>
 __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
 {
-    const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-    const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+    const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+    const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

     return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
 }
composable_kernel/include/utility/amd_address_space.hpp
new file 0 → 100644

#ifndef CK_AMD_ADDRESS_SPACE_HPP
#define CK_AMD_ADDRESS_SPACE_HPP

#include "config.hpp"
#include "c_style_pointer_cast.hpp"

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

namespace ck {

enum AddressSpaceEnum_t
{
    Generic,
    Global,
    Lds,
    Sgpr,
    Vgpr,
};

template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
{
    // cast a pointer in "Constant" address space (4) to "Generic" address space (0);
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
    // cast a pointer in "Generic" address space (0) to "Constant" address space (4);
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T CONSTANT*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

} // namespace ck
#endif
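A minimal sketch (not part of this commit) of how the two helpers pair up, assuming the umbrella common_header.hpp include brings in this header; the kernel name and parameter are hypothetical:

    // hypothetical example: a kernel parameter arrives in the "Constant"
    // address space (4) and is moved back to "Generic" (0) before normal use
    #include "common_header.hpp"

    extern "C" __global__ void example_kernel(const float CONSTANT* p_param)
    {
        const float* p = ck::cast_pointer_to_generic_address_space(p_param);
        (void)p; // ordinary generic-address-space loads from here on
    }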
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp → composable_kernel/include/utility/amd_buffer_addressing.hpp

-#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
-#define CK_AMD_BUFFER_ADDRESSING_V2_HPP
+#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
+#define CK_AMD_BUFFER_ADDRESSING_HPP

 #include "data_type.hpp"

 namespace ck {

 template <typename T>
-union BufferResource_v2
+union BufferResource
 {
     // 128 bit SGPRs to supply buffer resource in buffer instructions
     // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
-    int32x4_t data;
+    int32x4_t content;
     StaticallyIndexedArray<T*, 2> address;
     StaticallyIndexedArray<int32_t, 4> range;
     StaticallyIndexedArray<int32_t, 4> config;
 };

 template <typename T>
-__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
+__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
 {
-    BufferResource_v2<T> wave_buffer_resource;
+    BufferResource<T> wave_buffer_resource;

     // wavewise base address (64 bit)
     wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
     // wavewise range (32 bit)
-    wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
+    wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
     // wavewise setting (32 bit)
     wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

-    return wave_buffer_resource.data;
+    return wave_buffer_resource.content;
 }

 // load

@@ -204,8 +204,7 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");

 template <typename T, index_t N>
-__device__ typename vector_type<T, N>::type
-amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
+__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
 {

@@ -412,7 +411,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 }

 template <typename T, index_t N>
-__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
+__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                          int32x4_t dst_wave_buffer_resource,
                                          index_t dst_thread_addr_offset,
                                          index_t dst_wave_addr_offset)

@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 // buffer_load requires:
 //   1) p_src_wave must be in global memory space
-//   2) p_src_wave to be a wavewise pointer.
+//   2) p_src_wave must be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
 __device__ typename vector_type_maker<T, N>::type::type
-amd_buffer_load_v2(const T* p_src_wave,
-                   index_t src_thread_data_offset,
-                   bool src_thread_data_valid,
-                   index_t src_element_space)
+amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
+                                                   index_t src_thread_element_offset,
+                                                   bool src_thread_element_valid,
+                                                   index_t src_element_space_size)
 {
     const int32x4_t src_wave_buffer_resource =
-        make_wave_buffer_resource(p_src_wave, src_element_space);
+        make_wave_buffer_resource(p_src_wave, src_element_space_size);

-    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

     using vector_t = typename vector_type_maker<T, N>::type::type;
     using scalar_t = typename scalar_type<vector_t>::type;

     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
+    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;

-    return amd_buffer_load_impl_v2<scalar_t, vector_size>(
+    return amd_buffer_load_impl<scalar_t, vector_size>(
         src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
 #else
-    vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
+    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
         src_wave_buffer_resource, src_thread_addr_offset, 0);

-    return src_thread_data_valid ? tmp : vector_t(0);
+    return src_thread_element_valid ? tmp : vector_t(0);
 #endif
 }

+// buffer_load requires:
+//   1) p_src_wave must be in global memory space
+//   2) p_src_wave must be a wavewise pointer.
+// It is user's responsibility to make sure that is true.
+template <typename T, index_t N>
+__device__ typename vector_type_maker<T, N>::type::type
+amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
+                                                        index_t src_thread_element_offset,
+                                                        bool src_thread_element_valid,
+                                                        index_t src_element_space_size,
+                                                        T customized_value)
+{
+    const int32x4_t src_wave_buffer_resource =
+        make_wave_buffer_resource(p_src_wave, src_element_space_size);
+
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    using vector_t = typename vector_type_maker<T, N>::type::type;
+    using scalar_t = typename scalar_type<vector_t>::type;
+
+    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
+    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
+        src_wave_buffer_resource, src_thread_addr_offset, 0);
+
+    return src_thread_element_valid ? tmp : vector_t(customized_value);
+}
+
 // buffer_store requires:
 //   1) p_dst_wave must be global memory
 //   2) p_dst_wave to be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
-__device__ void amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
+__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
                                     T* p_dst_wave,
-                                    const index_t dst_thread_data_offset,
-                                    const bool dst_thread_data_valid,
-                                    const index_t dst_element_space)
+                                    const index_t dst_thread_element_offset,
+                                    const bool dst_thread_element_valid,
+                                    const index_t dst_element_space_size)
 {
     const int32x4_t dst_wave_buffer_resource =
-        make_wave_buffer_resource(p_dst_wave, dst_element_space);
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

-    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

     using vector_t = typename vector_type_maker<T, N>::type::type;
     using scalar_t = typename scalar_type<vector_t>::type;

     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;

-    amd_buffer_store_impl_v2<scalar_t, vector_size>(
+    amd_buffer_store_impl<scalar_t, vector_size>(
         src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
 #else
-    if(dst_thread_data_valid)
+    if(dst_thread_element_valid)
     {
-        amd_buffer_store_impl_v2<scalar_t, vector_size>(
+        amd_buffer_store_impl<scalar_t, vector_size>(
             src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
     }
 #endif
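A note on the 0x7fffffff shift retained from the _v2 code: under the experimental OOB-check trick, an invalid element does not branch; its byte offset is pushed past the range field programmed into the buffer resource, so the hardware range check of buffer_load/buffer_store drops the access (loads return 0). A minimal sketch of calling the renamed loader, assuming vector_type_maker<float, 4>::type::type is float4_t; the wrapper function is hypothetical:

    #include "amd_buffer_addressing.hpp"

    // hypothetical wrapper: wavewise load of 4 floats, zeros for invalid lanes
    __device__ ck::float4_t load_four(const float* p_src_wave,
                                      ck::index_t src_thread_element_offset,
                                      bool src_thread_element_valid,
                                      ck::index_t src_element_space_size)
    {
        return ck::amd_buffer_load_invalid_element_return_return_zero<float, 4>(
            p_src_wave, src_thread_element_offset, src_thread_element_valid,
            src_element_space_size);
    }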
composable_kernel/include/utility/amd_dlop.hpp
deleted 100644 → 0

#ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP

#include "data_type.hpp"

namespace ck {

template <typename TA, typename TB, typename TC>
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);

template <>
__device__ void
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
            v_fmac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c += a * b;
#endif
}

template <>
__device__ void
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
                           vector_type<float, 2>{b}.AsType<float>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
                           vector_type<float, 2>{b}.AsType<float>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
                           vector_type<float, 4>{b}.AsType<float>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
                           vector_type<float, 4>{b}.AsType<float>()[I1],
                           c);
    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
                           vector_type<float, 4>{b}.AsType<float>()[I2],
                           c);
    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
                           vector_type<float, 4>{b}.AsType<float>()[I3],
                           c);
}

#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
            v_dot2_f32_f16 %0, %1, %2, %0 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
}

template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
                           c);
    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
                           c);
    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0 \n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}

template <>
__device__ void
amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
                                                                      const int8x16_t& b,
                                                                      int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                           c);
    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                           c);
    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                           c);
    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                           c);
}
#endif // CK_USE_AMD_DLOP

} // namespace ck
#endif
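Call sites migrate to the new inner_product.hpp (added later in this commit); the mapping is one-to-one. A sketch with a hypothetical caller:

    #include "inner_product.hpp"

    __device__ void migrated_call(const ck::half2_t& a, const ck::half2_t& b, float& c)
    {
        // before: ck::amd_inner_product_dlop<ck::half2_t, ck::half2_t, float>(a, b, c);
        ck::inner_product<ck::half2_t, ck::half2_t, float>(a, b, c);
    }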
composable_kernel/include/utility/amd_inline_asm.hpp

@@ -2,6 +2,9 @@
 #define CK_AMD_INLINE_ASM_HPP

 #include "data_type.hpp"
+#include "c_style_pointer_cast.hpp"
+
+// TODO: deprecate all amd_assembly_outer_product_xxx

 namespace ck {

@@ -53,9 +56,9 @@ __device__ void
 amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
 {
     // TODO remove pointer casting
-    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
-    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
-    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
+    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
+    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
+    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);

     // do dot2 two times
     asm volatile("\n \

@@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                                float& c3)
 {
     // TODO remove pointer casting
-    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
-    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
-    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
-    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
-    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
+    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
+    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
+    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
+    const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
+    const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);

     // do dot2 two times
     asm volatile("\n \

@@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
 {
     // TODO remove pointer casting
-    const half4_t* p_a_half4  = reinterpret_cast<const half4_t*>(&a);
-    const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
-    const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
-    const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
-    const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
+    const half4_t* p_a_half4  = c_style_pointer_cast<const half4_t*>(&a);
+    const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
+    const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
+    const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
+    const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);

     amd_assembly_outer_product_1x4(
         p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);

@@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                                float& c3)
 {
     // TODO remove pointer casting
-    const half8_t* p_a_half8  = reinterpret_cast<const half8_t*>(&a);
-    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
-    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
-    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
-    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
+    const half8_t* p_a_half8  = c_style_pointer_cast<const half8_t*>(&a);
+    const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
+    const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
+    const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
+    const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);

     amd_assembly_outer_product_1x4(
         p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
composable_kernel/include/utility/c_style_pointer_cast.hpp
new file 0 → 100644

#ifndef CK_C_STYLE_POINTER_CAST_HPP
#define CK_C_STYLE_POINTER_CAST_HPP

#include "type.hpp"
#include "enable_if.hpp"

namespace ck {

template <typename PY,
          typename PX,
          typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
    return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}

} // namespace ck
#endif
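This helper is what the rest of the commit substitutes for reinterpret_cast; it performs the same conversion through a C-style cast while locally silencing -Wold-style-cast and -Wcast-align, and is constrained to pointer-to-pointer conversions via the new is_pointer_v trait. A sketch of the pattern applied in xdlops_gemm.hpp and amd_inline_asm.hpp; the wrapper function is hypothetical:

    #include "c_style_pointer_cast.hpp"
    #include "data_type.hpp"

    __device__ const ck::half2_t* as_half2_ptr(const unsigned short* p)
    {
        // previously: reinterpret_cast<const ck::half2_t*>(p)
        return ck::c_style_pointer_cast<const ck::half2_t*>(p);
    }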
composable_kernel/include/utility/common_header.hpp

@@ -7,13 +7,14 @@
 #include "statically_indexed_array.hpp"
 #include "container_element_picker.hpp"
 #include "multi_index.hpp"
-#include "data_type_enum.hpp"
 #include "data_type.hpp"
-#include "data_type_helper.hpp"
+#include "data_type_enum.hpp"
+#include "data_type_enum_helper.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
 #include "functional4.hpp"
+#include "enable_if.hpp"
 #include "integral_constant.hpp"
 #include "math.hpp"
 #include "number.hpp"

@@ -23,21 +24,21 @@
 #include "tuple.hpp"
 #include "tuple_helper.hpp"
 #include "type.hpp"
-#include "utility.hpp"
 #include "magic_division.hpp"
-#include "amd_buffer_addressing_v2.hpp"
+#include "utility.hpp"
+#include "c_style_pointer_cast.hpp"
+#include "amd_address_space.hpp"
+#include "amd_buffer_addressing.hpp"
 #include "static_buffer.hpp"
 #include "dynamic_buffer.hpp"
+#include "inner_product.hpp"

 // TODO: remove this
 #if CK_USE_AMD_INLINE_ASM
 #include "amd_inline_asm.hpp"
 #endif

-#if CK_USE_AMD_DLOP
-#include "amd_dlop.hpp"
-#endif
-
 #if CK_USE_AMD_XDLOPS
 #include "amd_xdlops.hpp"
 #endif
composable_kernel/include/utility/config.hpp

@@ -7,19 +7,14 @@
 #endif

 #include "bfloat16_dev.hpp"

-// address space for kernel parameter
+// "Constant" address space for kernel parameter
 #define CONSTANT __attribute__((address_space(4)))

 // GPU target
 // should enable one and only one GPU target
 #if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
       defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
-#error Need to define a single GPU target
-#endif
-
-// HIP version
-#ifndef CK_HIP_VERSION_FLAT
-#define CK_HIP_VERSION_FLAT 0
+#error Need to define (only) one GPU target
 #endif

 // launch bounds

@@ -38,6 +33,16 @@
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
 #endif

+// FMA instruction
+#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
+#define CK_USE_AMD_V_MAC_F32
+#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
+    defined(CK_AMD_GPU_GFX1030)
+#define CK_USE_AMD_V_FMAC_F32
+#define CK_USE_AMD_V_DOT2_F32_F16
+#define CK_USE_AMD_V_DOT4_I32_I8
+#endif
+
 // multi index
 #define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0

@@ -46,13 +51,9 @@
 #define CK_USE_AMD_INLINE_ASM 1
 #endif

-// AMD DLOPS
-#ifndef CK_USE_AMD_DLOP
-#define CK_USE_AMD_DLOP 1
-#endif
-
-#ifndef CK_USE_AMD_DLOP_INLINE_ASM
-#define CK_USE_AMD_DLOP_INLINE_ASM 1
+// AMD inner product (DLOP)
+#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
+#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
 #endif

 // AMD buffer addressing

@@ -99,8 +100,8 @@
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
 // thread-invariant, otherwise it's a bug
 // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
-#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
+#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
+#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
 #endif

 // workaround for compiler crash when compiling recursive lambda

@@ -120,15 +121,6 @@
 namespace ck {

-enum AddressSpaceEnum_t
-{
-    Generic,
-    Global,
-    Lds,
-    Sgpr,
-    Vgpr
-};
-
 enum InMemoryDataOperationEnum_t
 {
     Set,
composable_kernel/include/utility/data_type_enum.hpp

@@ -3,8 +3,7 @@
 namespace ck {

 // this enumerate should be synchronized with include/miopen.h
-typedef enum
+enum DataTypeEnum_t
 {
     Half  = 0,
     Float = 1,

@@ -14,7 +13,7 @@ typedef enum
     BFloat16 = 5,
     Double   = 6,
     Unknown  = 100,
-} DataTypeEnum_t;
+};

 } // namespace ck
 #endif
composable_kernel/include/utility/data_type_helper.hpp → composable_kernel/include/utility/data_type_enum_helper.hpp

-#ifndef CK_DATA_TYPE_HELPER_HPP
-#define CK_DATA_TYPE_HELPER_HPP
+#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
+#define CK_DATA_TYPE_ENUM_HELPER_HPP

 #include "data_type.hpp"
 #include "data_type_enum.hpp"
composable_kernel/include/utility/dynamic_buffer.hpp

-#ifndef CK_DYNAMIC_BUFFER_HPP
-#define CK_DYNAMIC_BUFFER_HPP
+#ifndef CK_BUFFER_HPP
+#define CK_BUFFER_HPP

-#include "amd_buffer_addressing_v2.hpp"
+#include "amd_buffer_addressing.hpp"
+#include "c_style_pointer_cast.hpp"
+#include "enable_if.hpp"

 namespace ck {

-template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
+template <AddressSpaceEnum_t BufferAddressSpace,
+          typename T,
+          typename ElementSpaceSize,
+          bool InvalidElementUseNumericalZeroValue>
 struct DynamicBuffer
 {
     using type = T;

     T* p_data_;
     ElementSpaceSize element_space_size_;
+    T invalid_element_value_ = T{0};

     __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
         : p_data_{p_data}, element_space_size_{element_space_size}
     {
     }

+    __host__ __device__ constexpr DynamicBuffer(T* p_data,
+                                                ElementSpaceSize element_space_size,
+                                                T invalid_element_value)
+        : p_data_{p_data},
+          element_space_size_{element_space_size},
+          invalid_element_value_{invalid_element_value}
+    {
+    }
+
     __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
     {
         return BufferAddressSpace;
     }

     __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

     __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

     template <typename X,
-              typename std::enable_if<
+              typename enable_if<
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
-    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
+    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
     {
         // X contains multiple T
         constexpr index_t scalar_per_t_vector =

@@ -44,29 +55,50 @@ struct DynamicBuffer
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                       "wrong! X need to be multiple T");

-        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
-
-        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
-        {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
-                p_data_, i, is_valid_offset, element_space_size_);
+        bool constexpr use_amd_buffer_addressing = true;
 #else
-            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
+        bool constexpr use_amd_buffer_addressing = false;
 #endif
+
+        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
+        {
+            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+            if constexpr(InvalidElementUseNumericalZeroValue)
+            {
+                return amd_buffer_load_invalid_element_return_return_zero<
+                    remove_cv_t<remove_reference_t<T>>,
+                    t_per_x>(p_data_, i, is_valid_element, element_space_size_);
+            }
+            else
+            {
+                return amd_buffer_load_invalid_element_return_customized_value<
+                    remove_cv_t<remove_reference_t<T>>,
+                    t_per_x>(
+                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
+            }
         }
         else
         {
+            if constexpr(InvalidElementUseNumericalZeroValue)
+            {
+                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
+            }
+            else
+            {
+                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
+                                        : X{invalid_element_value_};
+            }
         }
     }

     template <typename X,
-              typename std::enable_if<
+              typename enable_if<
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
-    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
+    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
     {
         // X contains multiple T
         constexpr index_t scalar_per_t_vector =

@@ -78,26 +110,26 @@ struct DynamicBuffer
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                       "wrong! X need to be multiple T");

-        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
-
         if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
         {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
-                x, p_data_, i, is_valid_offset, element_space_size_);
+            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+            amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
+                x, p_data_, i, is_valid_element, element_space_size_);
 #else
-            if(is_valid_offset)
+            if(is_valid_element)
             {
-                *reinterpret_cast<X*>(&p_data_[i]) = x;
+                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
             }
 #endif
         }
         else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
         {
-            if(is_valid_offset)
+            if(is_valid_element)
             {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
-                *reinterpret_cast<X*>(&p_data_[i]) = x;
+                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #else
                 // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
                 // inefficient

@@ -128,24 +160,24 @@ struct DynamicBuffer
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int8_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int8_t*>(&x);
+                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int8_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                   is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int16_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int16_t*>(&x);
+                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int16_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                   is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int32_t*>(&x);
+                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int32_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&

@@ -153,8 +185,8 @@ struct DynamicBuffer
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int32_t*>(&x);
+                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int32_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&

@@ -162,8 +194,8 @@ struct DynamicBuffer
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int32x2_t*>(&x);
+                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int32x2_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&

@@ -171,22 +203,22 @@ struct DynamicBuffer
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
-                    *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
-                        *reinterpret_cast<const int32x4_t*>(&x);
+                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
+                        *c_style_pointer_cast<const int32x4_t*>(&x);
                 }
                 else
                 {
-                    *reinterpret_cast<X*>(&p_data_[i]) = x;
+                    *c_style_pointer_cast<X*>(&p_data_[i]) = x;
                 }
 #endif
             }
         }
         else
         {
-            if(is_valid_offset)
+            if(is_valid_element)
             {
-                *reinterpret_cast<X*>(&p_data_[i]) = x;
+                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
             }
         }
     }

@@ -196,12 +228,18 @@ struct DynamicBuffer
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };

-template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
-          typename T,
-          typename ElementSpaceSize>
+template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
 {
-    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
+    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
 }

+template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
+__host__ __device__ constexpr auto
+make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
+{
+    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
+        p, element_space_size, invalid_element_value};
+}
+
 } // namespace ck
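A sketch of the two construction paths, assuming a global float buffer; with the extra argument, invalid elements read back as the sentinel instead of zero (the function and names are hypothetical):

    #include "dynamic_buffer.hpp"

    __device__ void example(float* p, ck::index_t n)
    {
        using namespace ck;

        // InvalidElementUseNumericalZeroValue = true: invalid reads yield 0
        auto buf_zero = make_dynamic_buffer<AddressSpaceEnum_t::Global>(p, n);

        // InvalidElementUseNumericalZeroValue = false: invalid reads yield -1.0f
        auto buf_sentinel = make_dynamic_buffer<AddressSpaceEnum_t::Global>(p, n, -1.0f);

        (void)buf_zero;
        (void)buf_sentinel;
    }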
composable_kernel/include/utility/enable_if.hpp
new file 0 → 100644

#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP

namespace ck {

template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;

} // namespace ck
#endif
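The aliases let device headers spell the constraint as ck::enable_if rather than std::enable_if; math.hpp, tuple.hpp, and type.hpp below are converted to them. A minimal sketch of the SFINAE pattern they support (the function itself is hypothetical, and the explicit <type_traits> include is an assumption this header leaves to its users):

    #include <type_traits>
    #include "enable_if.hpp"

    namespace ck {

    // enabled only for two or more trailing arguments, mirroring gcd/lcm in math.hpp
    template <typename X,
              typename... Ys,
              typename enable_if<(sizeof...(Ys) >= 2), bool>::type = false>
    __host__ __device__ constexpr auto sum(X x, Ys... ys)
    {
        return (x + ... + ys);
    }

    } // namespace ck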
composable_kernel/include/utility/inner_product.hpp
new file 0 → 100644

#ifndef CK_INNER_PRODUCT_HPP
#define CK_INNER_PRODUCT_HPP

#include "data_type.hpp"

namespace ck {

template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);

template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
    asm volatile("\n \
            v_mac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
    asm volatile("\n \
            v_fmac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c += a * b;
#endif
}

template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
                  vector_type<float, 2>{b}.AsType<float>()[I0],
                  c);
    inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
                  vector_type<float, 2>{b}.AsType<float>()[I1],
                  c);
}

template <>
__device__ void
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
                  vector_type<float, 4>{b}.AsType<float>()[I0],
                  c);
    inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
                  vector_type<float, 4>{b}.AsType<float>()[I1],
                  c);
    inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
                  vector_type<float, 4>{b}.AsType<float>()[I2],
                  c);
    inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
                  vector_type<float, 4>{b}.AsType<float>()[I3],
                  c);
}

template <>
__device__ void
inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
    asm volatile("\n \
            v_dot2_f32_f16 %0, %1, %2, %0 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
#else
    const auto convert = type_convert<int32_t>{};

    const vector_type<half_t, 2> a_vector{a};
    const vector_type<half_t, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
    });
#endif
}

template <>
__device__ void
inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
                  vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
                  c);
    inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
                  vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
                  c);
}

template <>
__device__ void
inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
                  c);
    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
                  c);
    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
                  c);
    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
                  c);
}

template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0 \n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
#else
    const auto convert = type_convert<int32_t>{};

    const vector_type<int8_t, 4> a_vector{a};
    const vector_type<int8_t, 4> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
    });
#endif
}

template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                  c);
    inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                  c);
}

template <>
__device__ void
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                  c);
    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                  c);
    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                  c);
    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                  c);
}

} // namespace ck
#endif
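A sketch of calling the new entry point, assuming the vector types from data_type.hpp; the half4_t specialization decomposes into two half2_t dot products, which use v_dot2_f32_f16 when the target defines CK_USE_AMD_V_DOT2_F32_F16 and a scalar fallback otherwise (the wrapper function is hypothetical):

    #include "inner_product.hpp"

    __device__ float accumulate(const ck::half4_t& a, const ck::half4_t& b, float acc)
    {
        // acc += <a, b>
        ck::inner_product<ck::half4_t, ck::half4_t, float>(a, b, acc);
        return acc;
    }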
composable_kernel/include/utility/math.hpp

@@ -5,6 +5,7 @@
 #include "integral_constant.hpp"
 #include "number.hpp"
 #include "type.hpp"
+#include "enable_if.hpp"

 namespace ck {
 namespace math {

@@ -27,13 +28,7 @@ struct minus
     __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
 };

-template <typename T>
-struct multiplies
-{
-    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
-};
-
 struct multiplies_v2
 {
     template <typename A, typename B>
     __host__ __device__ constexpr auto operator()(const A& a, const B& b) const

@@ -184,9 +179,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
     return Number<r>{};
 }

-template <typename X,
-          typename... Ys,
-          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
+template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto gcd(X x, Ys... ys)
 {
     return gcd(x, gcd(ys...));

@@ -199,9 +192,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
     return (x * y) / gcd(x, y);
 }

-template <typename X,
-          typename... Ys,
-          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
+template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto lcm(X x, Ys... ys)
 {
     return lcm(x, lcm(ys...));
composable_kernel/include/utility/print.hpp

@@ -11,59 +11,11 @@ namespace ck {
 template <typename T>
-__host__ __device__ void print_array(const char* s, T a)
-{
-    using data_type         = decltype(a.At(Number<0>{}));
-    constexpr index_t nsize = a.Size();
-
-#if 0
-    if constexpr(is_same<data_type, uint32_t>{})
-    {
-        printf("%s size %u, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
-        printf("}\n");
-    }
-    else if constexpr(is_same<data_type, int32_t>{})
-    {
-        printf("%s size %d, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
-        printf("}\n");
-    }
-    else if constexpr(is_same<data_type, bool>{})
-    {
-        printf("%s size %d, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
-        printf("}\n");
-    }
-#else
-    printf("%s size %d, {", s, nsize);
-    static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
-    printf("}\n");
-#endif
-}
-
-template <typename T>
 __host__ __device__ void print_array_v2(const char* s, T a)
 {
-    using data_type         = decltype(a.At(Number<0>{}));
     constexpr index_t nsize = a.Size();

-#if 0
-    if constexpr(is_same<data_type, uint32_t>{})
-    {
-        printf("%s size %u, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
-        printf("}\n");
-    }
-    else if constexpr(is_same<data_type, int32_t>{})
-    {
-        printf("%s size %d, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
-        printf("}\n");
-    }
-#else
     printf("%s size %d, {", s, nsize);
     static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
     printf("}\n");
-#endif
 }

 } // namespace ck
composable_kernel/include/utility/sequence.hpp

@@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
 template <index_t Y, index_t... Xs>
 __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
 {
-    constexpr auto seq_x = Sequence<Xs...>{};
-
     return Sequence<(Y - Xs)...>{};
 }
composable_kernel/include/utility/static_buffer.hpp

@@ -5,30 +5,66 @@
 namespace ck {

-template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
+template <AddressSpaceEnum_t BufferAddressSpace,
+          typename T,
+          index_t N,
+          bool InvalidElementUseNumericalZeroValue>
 struct StaticBuffer : public StaticallyIndexedArray<T, N>
 {
     using type = T;
     using base = StaticallyIndexedArray<T, N>;

+    T invalid_element_value_ = T{0};
+
+    __host__ __device__ constexpr StaticBuffer() : base{} {}
+
+    __host__ __device__ constexpr StaticBuffer(T invalid_element_value)
+        : base{}, invalid_element_value_{invalid_element_value}
+    {
+    }
+
     __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
     {
         return BufferAddressSpace;
     }

+    template <index_t I>
+    __host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
+    {
+        if constexpr(InvalidElementUseNumericalZeroValue)
+        {
+            return is_valid_element ? At(i) : T{0};
+        }
+        else
+        {
+            return is_valid_element ? At(i) : invalid_element_value_;
+        }
+    }
+
+    template <index_t I>
+    __host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
+    {
+        if(is_valid_element)
+        {
+            At(i) = x;
+        }
+    }
+
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

     __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 };

-template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
-          typename T,
-          index_t N>
+template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-    return StaticBuffer<BufferAddressSpace, T, N>{};
+    return StaticBuffer<BufferAddressSpace, T, N, true>{};
 }

+template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
+__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
+{
+    return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
+}
+
 } // namespace ck
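The matching factories for register-resident buffers follow the same convention as make_dynamic_buffer above; a sketch with a hypothetical caller:

    #include "static_buffer.hpp"

    __device__ void example()
    {
        using namespace ck;

        // 8 floats in VGPRs; invalid reads yield 0
        auto buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, float>(Number<8>{});

        // 8 floats in VGPRs; invalid reads yield the sentinel -1.0f
        auto buf2 = make_static_buffer<AddressSpaceEnum_t::Vgpr, float>(Number<8>{}, -1.0f);

        (void)buf;
        (void)buf2;
    }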
composable_kernel/include/utility/tuple.hpp

@@ -4,6 +4,7 @@
 #include "integral_constant.hpp"
 #include "sequence.hpp"
 #include "type.hpp"
+#include "enable_if.hpp"

 namespace ck {

@@ -20,9 +21,8 @@ struct TupleElement
 {
     __host__ __device__ constexpr TupleElement() = default;

     template <typename T,
-              typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
-                                      bool>::type = false>
+              typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
+                                 bool>::type = false>
     __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
     {

@@ -58,9 +58,8 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
 {
     __host__ __device__ constexpr TupleImpl() = default;

     template <typename Y,
-              typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
+              typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
                                      !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)

@@ -68,7 +67,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
     {
     }

-    template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
+    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Ys&&... ys)
         : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
     {

@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ constexpr Tuple() = default;

     template <typename Y,
-              typename std::enable_if<sizeof...(Xs) == 1 &&
-                                          !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
-                                      bool>::type = false>
+              typename enable_if<sizeof...(Xs) == 1 &&
+                                     !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
+                                 bool>::type = false>
     __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
     {
     }

     template <typename... Ys,
-              typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
-                                      bool>::type = false>
+              typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
+                                 bool>::type = false>
     __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
     {
     }
composable_kernel/include/utility/type.hpp

@@ -2,6 +2,7 @@
 #define CK_TYPE_HPP

 #include "integral_constant.hpp"
+#include "enable_if.hpp"

 namespace ck {

@@ -22,10 +23,7 @@ template <typename T>
 using remove_cv_t = typename std::remove_cv<T>::type;

 template <typename T>
-constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
-{
-    return static_cast<typename std::remove_reference<T>::type&&>(t);
-}
+inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

 template <typename T>
 struct is_known_at_compile_time;

@@ -42,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
     static constexpr bool value = true;
 };

-template <typename Y,
-          typename X,
-          typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
+template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
 __host__ __device__ constexpr Y as_type(X x)
 {
     union AsType
composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp → composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp

 #include "common_header.hpp"
-#include "dynamic_tensor_descriptor.hpp"
-#include "dynamic_tensor_descriptor_helper.hpp"
-#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "gridwise_gemm_dlops_v1r2.hpp"
 #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 using namespace ck;

@@ -64,8 +64,7 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs
 constexpr bool HasMainKBlockLoop       = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
 constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);

-extern "C" __global__ void
-dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
+extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
     int n,
     int c,
     int hi,

@@ -93,12 +92,9 @@ convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
     const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
     const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;

-    const auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
-    const auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
-    const auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
+    const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
+    const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
+    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));

     const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,

@@ -117,7 +113,7 @@
     using BKNGridDesc = decltype(b_k_n_grid_desc);
     using CMNGridDesc = decltype(c_m_n_grid_desc);

-    using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+    using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{}),

@@ -126,7 +122,7 @@
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{})));

-    using BGridIteratorHacks =
+    using BGridStepHacks =
         decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),

@@ -134,7 +130,7 @@
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

-    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 1, 0, 0>{},

@@ -147,11 +143,11 @@
                                                    Sequence<0, 0, 2, 0, 0>{},
                                                    Sequence<0, 0, 2, 0, 0>{})));

-    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
-    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
+    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
+    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

-    using GridwiseGemm = GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
+    using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
                                                          FloatAB,
                                                          FloatAcc,
                                                          FloatC,

@@ -188,11 +184,11 @@
                                                          CThreadTransferSrcDstAccessOrder,
                                                          CThreadTransferSrcDstVectorDim,
                                                          CThreadTransferDstScalarPerVector,
-                                                         AGridIteratorHacks,
-                                                         BGridIteratorHacks,
-                                                         CGridIteratorHacks,
-                                                         AGridMoveSliceWindowIteratorHacks,
-                                                         BGridMoveSliceWindowIteratorHacks>;
+                                                         AGridStepHacks,
+                                                         BGridStepHacks,
+                                                         CGridStepHacks,
+                                                         AGridMoveSliceWindowStepHacks,
+                                                         BGridMoveSliceWindowStepHacks>;

     auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
     auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

@@ -216,7 +212,7 @@ extern "C" __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
+        convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             FloatC* __restrict__ p_c_grid,

@@ -230,11 +226,11 @@ extern "C" __global__ void
     constexpr auto I2 = Number<2>{};

     constexpr auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
+        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
     constexpr auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
+        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
     constexpr auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
+        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));

     constexpr auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,

@@ -253,7 +249,7 @@ extern "C" __global__ void
     using BKNGridDesc = decltype(b_k_n_grid_desc);
     using CMNGridDesc = decltype(c_m_n_grid_desc);

-    using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+    using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{}),

@@ -262,7 +258,7 @@ extern "C" __global__ void
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{})));

-    using BGridIteratorHacks =
+    using BGridStepHacks =
         decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),

@@ -270,7 +266,7 @@ extern "C" __global__ void
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

-    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 0, 0, 0>{},
                                                    Sequence<0, 0, 1, 0, 0>{},

@@ -283,11 +279,11 @@ extern "C" __global__ void
                                                    Sequence<0, 0, 2, 0, 0>{},
                                                    Sequence<0, 0, 2, 0, 0>{})));

-    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
-    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
+    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
+    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

-    using GridwiseGemm = GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
+    using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
                                                          FloatAB,
                                                          FloatAcc,
                                                          FloatC,

@@ -324,11 +320,11 @@ extern "C" __global__ void
                                                          CThreadTransferSrcDstAccessOrder,
                                                          CThreadTransferSrcDstVectorDim,
                                                          CThreadTransferDstScalarPerVector,
-                                                         AGridIteratorHacks,
-                                                         BGridIteratorHacks,
-                                                         CGridIteratorHacks,
-                                                         AGridMoveSliceWindowIteratorHacks,
-                                                         BGridMoveSliceWindowIteratorHacks>;
+                                                         AGridStepHacks,
+                                                         BGridStepHacks,
+                                                         CGridStepHacks,
+                                                         AGridMoveSliceWindowStepHacks,
+                                                         BGridMoveSliceWindowStepHacks>;

     constexpr auto a_k_m0_m1_grid_desc_tmp = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);