gaoqiong / composable_kernel · Commits

Commit c3bcb6ae
authored Jun 04, 2019 by Jing Zhang

    asm

parent 917d7a2b
Showing 5 changed files with 271 additions and 17 deletions:

    driver/driver.hip.cpp                                                                        +2    -2
    src/include/amd_inline_asm.hip.hpp                                                           +132  -11
    src/include/blockwise_generic_tensor_slice_op.hip.hpp                                        +3    -2
    src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp   +4    -2
    src/include/threadwise_generic_tensor_slice_op.hip.hpp                                       +130  -0
driver/driver.hip.cpp

@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 3x3 filter, 28x28 image
     constexpr index_t N = 128;
     constexpr index_t C = 256;
@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
     // 1x1 filter, 28x28 image
     constexpr index_t N = 128;
     constexpr index_t C = 512;
src/include/amd_inline_asm.hip.hpp

@@ -8,7 +8,7 @@
 #define NO_GLB_READ 0

 // cast a pointer of LDS to its address
-extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
+extern "C" __attribute__((address_space(3))) void* __to_local(const void* p) [[hc]];

 __device__ void vmcnt(index_t cnt)
 {
@@ -721,18 +721,111 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
 #endif
 }

-__device__ void global_load(vector_type<float, 4>::MemoryType& r,
-                            const vector_type<float, 4>::MemoryType* ptr,
-                            index_t offset = 0)
+__device__ void global_loadx4(void* r, const void* ptr, index_t offset = 0)
 {
 #if !NO_GLB_READ
     if(offset == 0)
     {
+        //*(vector_type<float, 4>::MemoryType*)(r) = *(vector_type<float, 4>::MemoryType*)(ptr);
         asm volatile("\n \
         global_load_dwordx4 %0, %1, off \n \
         "
-                     : "=v"(r)
-                     : "v"(ptr));
+                     : "=v"(*(vector_type<float, 4>::MemoryType*)(r))
+                     : "r"(ptr));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void global_loadx2(void* r, const void* ptr, index_t offset = 0)
+{
+#if !NO_GLB_READ
+    if(offset == 0)
+    {
+        asm volatile("\n \
+        global_load_dwordx2 %0, %1, off \n \
+        "
+                     : "=v"(*(vector_type<float, 2>::MemoryType*)(r))
+                     : "r"(ptr));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void global_loadx1(void* r, const void* ptr, index_t offset = 0)
+{
+#if !NO_GLB_READ
+    if(offset == 0)
+    {
+        //*(float*)(r) = *(float*)(ptr);
+        asm volatile("\n \
+        global_load_dword %0, %1, off \n \
+        "
+                     : "=v"(*(float*)(r))
+                     : "r"(ptr));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void global_storex4(const void* ptr, const void* r, index_t offset = 0)
+{
+#if !NO_GLB_READ
+    if(offset == 0)
+    {
+        //*(vector_type<float, 4>::MemoryType*)(ptr) = *(vector_type<float, 4>::MemoryType*)(r);
+        asm volatile("\n \
+        global_store_dwordx4 %0, %1, off \n \
+        "
+                     :
+                     : "r"(ptr), "v"(*(vector_type<float, 4>::MemoryType*)(r)));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void global_storex2(const void* ptr, const void* r, index_t offset = 0)
+{
+#if !NO_GLB_READ
+    if(offset == 0)
+    {
+        asm volatile("\n \
+        global_store_dwordx2 %0, %1, off \n \
+        "
+                     :
+                     : "r"(ptr), "v"(*(vector_type<float, 2>::MemoryType*)(r)));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void global_storex1(const void* ptr, const void* r, index_t offset = 0)
+{
+#if !NO_GLB_READ
+    if(offset == 0)
+    {
+        //*(float*)(ptr) = *(float*)(r);
+        asm volatile("\n \
+        global_store_dword %0, %1, off \n \
+        "
+                     :
+                     : "r"(ptr), "v"(*(float*)(r)));
+    }
     }
     else
     {
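The rewritten wrappers take type-erased void* arguments and pin the exact GCN
instruction (global_load_dword / _dwordx2 / _dwordx4 and the matching stores)
instead of letting the compiler pick the access width. Below is a minimal
standalone sketch of the same pattern, untested: it assumes a HIP toolchain
targeting gfx9 and uses HIP's float4 in place of
vector_type<float, 4>::MemoryType; the "v" address constraint follows the
pre-commit code, whereas the commit itself switches the address operand to "r".

    #include <hip/hip_runtime.h>

    // copy one float4 per thread through explicit 128-bit global accesses
    __global__ void copy_x4(const float4* src, float4* dst)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;

        float4 r;
        // 128-bit global load; retires asynchronously
        asm volatile("global_load_dwordx4 %0, %1, off"
                     : "=v"(r)
                     : "v"(src + i));
        // drain the vector-memory counter before r is consumed
        asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
        // 128-bit global store of the loaded value
        asm volatile("global_store_dwordx4 %0, %1, off"
                     :
                     : "v"(dst + i), "v"(r));
    }

Note that the offset parameter of these wrappers is accepted but only
offset == 0 is implemented; any other value hits assert(false), so callers
must fold offsets into ptr before the call.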
@@ -741,17 +834,36 @@ __device__ void global_load(vector_type<float, 4>::MemoryType& r,
 #endif
 }

 __device__ void
-ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
+ds_write_b128(const void* lds, const void* r, index_t offset = 0)
 {
 #if !NO_DS_WRITE
     if(offset == 0)
     {
+        //*(vector_type<float, 4>::MemoryType*)(lds) = *(vector_type<float, 4>::MemoryType*)(r);
         asm volatile("\n \
         ds_write_b128 %0, %1 \n \
         "
                      :
-                     : "v"(__to_local(lds)), "v"(r));
+                     : "v"(__to_local(lds)), "v"(*(vector_type<float, 4>::MemoryType*)(r)));
+    }
+    else
+    {
+        assert(false);
+    }
+#endif
+}
+
+__device__ void ds_write_b32(const void* lds, const void* r, index_t offset = 0)
+{
+#if !NO_DS_WRITE
+    if(offset == 0)
+    {
+        //*(float*)(lds) = *(float*)(r);
+        asm volatile("\n \
+        ds_write_b32 %0, %1 \n \
+        "
+                     :
+                     : "v"(__to_local(lds)), "v"(*(float*)(r)));
+    }
     }
     else
     {
@@ -759,3 +871,12 @@ ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t off
     }
 #endif
 }
+
+__device__ void s_barrier()
+{
+    asm volatile("\n \
+    s_barrier \n \
+    "
+                 :
+                 :);
+}
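s_barrier wraps the raw GCN workgroup barrier. Unlike __syncthreads(), the
bare instruction does not wait for outstanding memory operations (the
compiler normally emits an s_waitcnt before the barriers it generates), which
is presumably why this commit pairs explicit vmcnt(0) calls with the LDS
double-buffer stores in the gridwise kernel below. A sketch of the
distinction, assuming gfx9 semantics:

    // workgroup barrier only; outstanding loads/stores may still be in flight
    __device__ void barrier_raw()
    {
        asm volatile("s_barrier" ::: "memory");
    }

    // drain global (vmcnt) and LDS/constant (lgkmcnt) counters, then barrier
    __device__ void barrier_fenced()
    {
        asm volatile("s_waitcnt vmcnt(0) lgkmcnt(0)\n\t"
                     "s_barrier" ::: "memory");
    }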
src/include/blockwise_generic_tensor_slice_op.hip.hpp

@@ -217,7 +217,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
             const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
                 clipboard_data_multi_id_begin);

             // cannot not constexpr, why?
-            threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
+            threadwise_generic_tensor_slice_copy_v1_asm(SrcDesc{},
                                                     p_src + src_offset + mThreadSrcOffset,
                                                     make_zero_array<index_t, nDim>(),
                                                     thread_tensor_desc,
@@ -225,7 +225,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
                                                     make_zero_array<index_t, nDim>(),
                                                     thread_sub_tensor_lengths,
                                                     SrcAccessOrder{},
-                                                    Number<SrcDataPerRead>{});
+                                                    Number<SrcDataPerRead>{},
+                                                    Number<1>{});
         });
     }
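The extra trailing argument is the new OpType template parameter of
threadwise_generic_tensor_slice_copy_v1_asm (added in
threadwise_generic_tensor_slice_op.hip.hpp below); Number<1>{} selects the
global-load path. The hypothetical named constants below mirror that dispatch
to make the magic numbers readable; the values come from the commit, but the
names are illustrative only:

    enum CopyOpType : index_t
    {
        CopyOp_GlobalLoad  = 1, // global_loadx1 / global_loadx4
        CopyOp_GlobalStore = 2, // global_storex1 / global_storex4
        CopyOp_LdsWrite    = 4  // ds_write_b32 / ds_write_b128
    };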
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp

@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_gemm.Run(Xs...);
 #else
             return blockwise_gemm.Run_asm(Xs...);
@@ -295,7 +295,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
         blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                     p_wei_register_clipboard);
+        vmcnt(0);
         blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
         blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                      p_wei_block_double);
@@ -336,6 +336,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
                 // LDS double buffer: GEMM on current data
                 run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
+                vmcnt(0);

                 // LDS double buffer: store next data to LDS
                 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                             p_in_block_next);
@@ -364,6 +365,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
             run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
+            vmcnt(0);
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                         p_in_block_double + in_block_space);
             blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
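The vmcnt(0) calls added here are waits, not loads: RunLoadRegisterClipboard
issues asynchronous global_load instructions into the register clipboard, so
the kernel must drain the vector-memory counter before
RunStoreRegisterClipboard reads those registers back out into LDS. Placing the
wait after run_blockwise_gemm lets the global loads overlap with the GEMM.
A minimal sketch of what vmcnt(0) has to reduce to (an assumption about the
wrapper's body; s_waitcnt takes an immediate, so only a compile-time-constant
count is expressible in a single instruction):

    __device__ __forceinline__ void wait_all_vmem()
    {
        // stall until every outstanding global load/store of this wavefront
        // has completed
        asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
    }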
src/include/threadwise_generic_tensor_slice_op.hip.hpp

@@ -97,3 +97,133 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     });
 #endif
 }
+
+template <class Float,
+          class SrcDesc,
+          class DstDesc,
+          class SliceLengths,
+          class DimAccessOrder,
+          index_t DataPerAccess,
+          index_t OpType>
+__device__ void threadwise_generic_tensor_slice_copy_v1_asm(
+    SrcDesc,
+    const Float* __restrict__ p_src,
+    Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
+    DstDesc,
+    Float* __restrict__ p_dst,
+    Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
+    SliceLengths,
+    DimAccessOrder,
+    Number<DataPerAccess>,
+    Number<OpType>)
+{
+    constexpr index_t nDim = SrcDesc::GetNumOfDimension();
+
+    static_assert(nDim == SrcDesc::GetNumOfDimension() &&
+                      nDim == DstDesc::GetNumOfDimension() &&
+                      nDim == SliceLengths::GetSize() &&
+                      nDim == DimAccessOrder::GetSize(),
+                  "wrong! # of dimensions not the same");
+
+    static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
+
+#if 0
+    // doesn't compile, because merged-tensor reordering is not implemented
+    // TODO: implement tensor desc ops for merged-tensor
+    constexpr auto src_strides_in_access_order =
+        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+
+    constexpr auto dst_strides_in_access_order =
+        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+
+    // check src/dst stride on the lowest access dimension
+    static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
+                  (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
+                  "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
+                  "vectorized read/write");
+#endif
+
+    constexpr auto slice_lengths_in_access_order =
+        SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
+
+    // check slice length on the lowest access dimension
+    static_assert(slice_lengths_in_access_order.Back() % DataPerAccess == 0,
+                  "wrong! slice length on the lowest access dimension should be evenly divided by "
+                  "DataPerAccess");
+
+    constexpr index_t num_access_on_lowest_access_dimension =
+        slice_lengths_in_access_order.Back() / DataPerAccess;
+
+    constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
+        Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});
+
+    using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
+
+#if 1
+    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
+        auto data_multi_id_in_access_order = access_multi_id;
+        data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess;
+
+        const auto data_multi_id =
+            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
+
+        const index_t src_index =
+            SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
+
+        const index_t dst_index =
+            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
+
+        static_assert(DataPerAccess == 1 || DataPerAccess == 4, "unsupported DataPerAccess");
+        static_assert(OpType == 1 || OpType == 2 || OpType == 4, "unsupported OpType");
+
+        if(DataPerAccess == 4)
+        {
+            if(OpType == 1)
+            {
+                global_loadx4(&p_dst[dst_index], &p_src[src_index]);
+            }
+            else if(OpType == 2)
+            {
+                global_storex4(&p_dst[dst_index], &p_src[src_index]);
+            }
+            else
+            {
+                ds_write_b128(&p_dst[dst_index], &p_src[src_index]);
+            }
+        }
+
+        if(DataPerAccess == 1)
+        {
+            if(OpType == 1)
+            {
+                global_loadx1(&p_dst[dst_index], &p_src[src_index]);
+            }
+            else if(OpType == 2)
+            {
+                global_storex1(&p_dst[dst_index], &p_src[src_index]);
+            }
+            else
+            {
+                ds_write_b32(&p_dst[dst_index], &p_src[src_index]);
+            }
+        }
+    });
+#else
+    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id_) {
+        const auto access_multi_id = sequence2array(access_multi_id_);
+
+        auto data_multi_id_in_access_order = access_multi_id;
+        data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess;
+
+        const auto data_multi_id =
+            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
+
+        const index_t src_index =
+            SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
+
+        const index_t dst_index =
+            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
+
+        *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
+            *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+    });
+#endif
+}
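The ford<decltype(access_lengths)> traversal above walks the slice in
DimAccessOrder with the lowest access dimension divided by DataPerAccess, then
scales that index back to element coordinates. A plain-C++ model of the index
arithmetic for a 2D slice (illustrative stand-in: the repo's ford/Array types
are replaced by explicit loops and raw arrays):

    #include <cstddef>

    int main()
    {
        // 4x8 slice, DimAccessOrder = {0, 1}, DataPerAccess = 4: the lowest
        // access dimension (length 8) is walked in 8 / 4 = 2 vector steps
        constexpr std::size_t lengths[2]        = {4, 8};
        constexpr std::size_t data_per_access   = 4;
        constexpr std::size_t access_lengths[2] = {lengths[0],
                                                   lengths[1] / data_per_access};

        float src[4][8] = {{1, 2, 3, 4, 5, 6, 7, 8}};
        float dst[4][8] = {};

        for(std::size_t i0 = 0; i0 < access_lengths[0]; ++i0)
            for(std::size_t i1 = 0; i1 < access_lengths[1]; ++i1)
            {
                // scale the access index back to element coordinates, as the
                // kernel does for data_multi_id_in_access_order[nDim - 1]
                const std::size_t j1 = i1 * data_per_access;

                // one "vectorized" access: data_per_access contiguous elements
                for(std::size_t k = 0; k < data_per_access; ++k)
                    dst[i0][j1 + k] = src[i0][j1 + k];
            }

        return dst[0][0] == 1.0f ? 0 : 1; // trivial sanity check
    }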