Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3321471c
"...composable_kernel.git" did not exist on "9e4429f9c3c6c08da06a65cc880b094850c4cb4e"
Commit
3321471c
authored
Apr 01, 2021
by
Jing Zhang
Browse files
add fp16
parent
0c883faa
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
184 additions
and
14 deletions
+184
-14
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
...sable_kernel/include/utility/amd_buffer_addressing_v2.hpp
+140
-1
composable_kernel/include/utility/amd_inline_asm.hpp
composable_kernel/include/utility/amd_inline_asm.hpp
+24
-0
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+1
-1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+6
-6
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+13
-6
No files found.
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
View file @
3321471c
...
@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
...
@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i32"
);
// half
// Declarations for the fp16 raw buffer loads. The __asm label attaches the
// quoted LLVM intrinsic name to each declaration, so calls resolve to that
// intrinsic rather than to a function defined in this project.
// NOTE(review): the parameters presumably mirror the intrinsic operands
// (resource descriptor, per-thread offset, wave-uniform offset, cache-policy
// bits) — confirm against the LLVM AMDGPU intrinsic definitions.

// Load one half_t.
__device__ half_t
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");

// Load a half2_t.
__device__ half2_t
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
                                     index_t voffset,
                                     index_t soffset,
                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");

// Load a half4_t.
__device__ half4_t
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
                                     index_t voffset,
                                     index_t soffset,
                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float
__device__
float
__device__
float
__llvm_amdgcn_raw_buffer_load_fp32
(
int32x4_t
srsrc
,
__llvm_amdgcn_raw_buffer_load_fp32
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
voffset
,
...
@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
...
@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i32"
);
// half
// Declarations for the fp16 raw buffer stores. The __asm label attaches the
// quoted LLVM intrinsic name to each declaration, so calls resolve to that
// intrinsic rather than to a function defined in this project.
// NOTE(review): vdata is the value to write; the remaining parameters
// presumably mirror the intrinsic operands (resource descriptor, per-thread
// offset, wave-uniform offset, cache-policy bits) — confirm against the LLVM
// AMDGPU intrinsic definitions.

// Store one half_t.
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");

// Store a half2_t.
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
                                      int32x4_t rsrc,
                                      index_t voffset,
                                      index_t soffset,
                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");

// Store a half4_t.
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
                                      int32x4_t rsrc,
                                      index_t voffset,
                                      index_t soffset,
                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__
void
__device__
void
__llvm_amdgcn_raw_buffer_store_fp32
(
float
vdata
,
__llvm_amdgcn_raw_buffer_store_fp32
(
float
vdata
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
...
@@ -142,6 +183,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...
@@ -142,6 +183,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t
src_wave_addr_offset
)
index_t
src_wave_addr_offset
)
{
{
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half4_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half8_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32x2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32x2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32x4_t
>::
value
&&
(
N
==
1
)),
(
is_same
<
T
,
int32x4_t
>::
value
&&
(
N
==
1
)),
...
@@ -177,6 +222,55 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...
@@ -177,6 +222,55 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return
tmp
.
Vector
();
return
tmp
.
Vector
();
}
}
}
}
else if constexpr(is_same<T, half_t>::value)
{
    // Scalar element type half_t: N picks the width of the load intrinsic.
    if constexpr(N == 1)
    {
        return __llvm_amdgcn_raw_buffer_load_fp16(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
    }
    else if constexpr(N == 2)
    {
        return __llvm_amdgcn_raw_buffer_load_fp16x2(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
    }
    else if constexpr(N == 4)
    {
        return __llvm_amdgcn_raw_buffer_load_fp16x4(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
    }
}
else if constexpr(is_same<T, half2_t>::value)
{
    // A single half2_t element is fetched with the x2 load.
    if constexpr(N == 1)
    {
        return __llvm_amdgcn_raw_buffer_load_fp16x2(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
    }
}
else if constexpr(is_same<T, half4_t>::value)
{
    // A single half4_t element is fetched with the x4 load.
    if constexpr(N == 1)
    {
        return __llvm_amdgcn_raw_buffer_load_fp16x4(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
    }
}
else if constexpr(is_same<T, half8_t>::value)
{
    if constexpr(N == 1)
    {
        // A half8_t is assembled from two half4_t loads, since the widest
        // fp16 load declared here is the x4 variant.
        vector_type<half_t, 8> tmp;

        // Lower four halves.
        tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

        // Upper four halves. The offset must step past the four halves already
        // read (offsets are in bytes per the LLVM raw buffer intrinsics);
        // reusing the same offset would fetch the lower half4_t twice.
        tmp.Vectors(Number<4>{})(Number<1>{}) =
            __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
                                                 src_thread_addr_offset,
                                                 src_wave_addr_offset + 4 * sizeof(half_t),
                                                 0);

        return tmp.Vector();
    }
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
{
if
constexpr
(
N
==
1
)
if
constexpr
(
N
==
1
)
...
@@ -234,7 +328,8 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
...
@@ -234,7 +328,8 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
static_assert
(
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
float
>::
value
)
...
@@ -334,6 +429,50 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
...
@@ -334,6 +429,50 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
0
);
0
);
}
}
}
}
else if constexpr(is_same<T, half_t>::value)
{
    // half_t stores: N picks the width of the store intrinsic.
    if constexpr(N == 1)
    {
        __llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset,
                                            0);
    }
    else if constexpr(N == 2)
    {
        __llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              0);
    }
    else if constexpr(N == 4)
    {
        __llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              0);
    }
    else if constexpr(N == 8)
    {
        // Eight halves are written as two half4_t stores, since the widest
        // fp16 store declared here is the x4 variant.
        vector_type<half_t, 8> tmp;
        tmp.Vector() = src_thread_data;

        // Lower four halves.
        __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              0);

        // Upper four halves. The offset must step past the four halves already
        // written (offsets are in bytes per the LLVM raw buffer intrinsics);
        // reusing the same offset would overwrite the lower half4_t and leave
        // the upper part of the destination untouched.
        __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset + 4 * sizeof(half_t),
                                              0);
    }
}
}
}
// buffer_load requires:
// buffer_load requires:
...
...
composable_kernel/include/utility/amd_inline_asm.hpp
View file @
3321471c
...
@@ -166,6 +166,30 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
...
@@ -166,6 +166,30 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
"3"
(
c3
));
"3"
(
c3
));
}
}
// half8_t overload of the 1x4 outer product.
// Each half8_t operand is viewed as two consecutive half4_t pieces, and the
// half4_t overload is invoked once per piece with the same c0..c3 references:
// first with piece 0 of every operand, then with piece 1.
// NOTE(review): presumably the half4_t overload accumulates into c0..c3, so
// the two calls together cover all eight lanes — confirm against that overload.
__device__ void amd_assembly_outer_product_1x4(half8_t a,
                                               half8_t b0,
                                               half8_t b1,
                                               half8_t b2,
                                               half8_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // Reinterpret a half8_t as an array of two half4_t.
    const auto as_half4 = [](const half8_t& v) { return reinterpret_cast<const half4_t*>(&v); };

    const half4_t* const a_parts  = as_half4(a);
    const half4_t* const b0_parts = as_half4(b0);
    const half4_t* const b1_parts = as_half4(b1);
    const half4_t* const b2_parts = as_half4(b2);
    const half4_t* const b3_parts = as_half4(b3);

    // Piece 0 of every operand.
    amd_assembly_outer_product_1x4(
        a_parts[0], b0_parts[0], b1_parts[0], b2_parts[0], b3_parts[0], c0, c1, c2, c3);

    // Piece 1 of every operand.
    amd_assembly_outer_product_1x4(
        a_parts[1], b0_parts[1], b1_parts[1], b2_parts[1], b3_parts[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0)
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c1 += inner_product(a, b1)
__device__
void
__device__
void
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
3321471c
...
@@ -53,7 +53,7 @@
...
@@ -53,7 +53,7 @@
// AMD buffer addressing
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING
0
#define CK_USE_AMD_BUFFER_ADDRESSING
1
#endif
#endif
// only gfx908 support native floating point atomic add
// only gfx908 support native floating point atomic add
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
3321471c
...
@@ -118,16 +118,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -118,16 +118,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
WoPerBlock
=
64
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
EPerBlock
=
1
;
constexpr
index_t
EPerBlock
=
2
;
constexpr
index_t
KPerThread
=
KPerBlock
;
constexpr
index_t
KPerThread
=
KPerBlock
;
constexpr
index_t
HoPerThread
=
4
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
EPerBlock
;
constexpr
index_t
EPerThread
=
EPerBlock
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
3
,
1
>
;
using
ABlockTransferThreadSliceLengths_E_K
=
Sequence
<
9
,
1
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
3
*
EPerBlock
,
KPerBlock
>
;
using
ABlockTransferThreadClusterLengths_E_K
=
Sequence
<
EPerBlock
,
KPerBlock
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K
=
1
;
...
@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
//DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
//
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
<
BlockSize
,
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
...
...
driver/src/conv_driver.cpp
View file @
3321471c
...
@@ -76,14 +76,14 @@ int main(int argc, char* argv[])
...
@@ -76,14 +76,14 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
0
#elif
1
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
constexpr
index_t
HI
=
1080
;
constexpr
index_t
HI
=
1080
;
constexpr
index_t
WI
=
1920
;
constexpr
index_t
WI
=
1920
;
constexpr
index_t
K
=
4
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
...
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
constexpr
index_t
HI
=
540
;
constexpr
index_t
HI
=
540
;
...
@@ -663,12 +663,17 @@ int main(int argc, char* argv[])
...
@@ -663,12 +663,17 @@ int main(int argc, char* argv[])
constexpr index_t in_vector_size = 1;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using acc_data_t = float;
using out_data_t = float;
using out_data_t = float;
#elif
1
using
in_data_t
=
half_t
;
constexpr
index_t
in_vector_size
=
8
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
#elif 0
#elif 0
using
in_data_t
=
float
;
using
in_data_t
=
float
;
constexpr
index_t
in_vector_size
=
1
;
constexpr
index_t
in_vector_size
=
1
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
#elif
1
#elif
0
using
in_data_t
=
int8_t
;
using
in_data_t
=
int8_t
;
constexpr
index_t
in_vector_size
=
16
;
constexpr
index_t
in_vector_size
=
16
;
using
acc_data_t
=
int32_t
;
using
acc_data_t
=
int32_t
;
...
@@ -816,6 +821,7 @@ int main(int argc, char* argv[])
...
@@ -816,6 +821,7 @@ int main(int argc, char* argv[])
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#if 0
if(do_log)
if(do_log)
{
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...
@@ -823,5 +829,6 @@ int main(int argc, char* argv[])
...
@@ -823,5 +829,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
}
}
#endif
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment