yangql/composable_kernel-1, commit cf218184
Authored Sep 30, 2019 by Chao Liu
Parent: 012d3a07

enable type conversion in blockwise copy v2 and threadwise copy v2r1
Showing 5 changed files with 161 additions and 138 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp (+13 -11)
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (+2 -0)
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp (+22 -15)
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp (+121 -109)
driver/src/driver.cpp (+3 -3)
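Across these files, the copy helpers gain a second element-type template parameter so that source and destination types can differ; data is staged in registers in the source type and converted only when it is stored. A minimal host-side sketch of that "convert during store" idea, with an illustrative type_convert stand-in and plain loops (the names and scalar loops are assumptions, not the repository's vectorized implementation):

#include <cstddef>

// Illustrative stand-in for a type_convert functor (assumption; the real one lives in the repository).
template <typename DstData>
struct type_convert
{
    template <typename SrcData>
    DstData operator()(SrcData v) const { return static_cast<DstData>(v); }
};

// Convert-during-store: the staging buffer keeps the source type; conversion
// happens only when elements are written to the destination.
template <typename SrcData, typename DstData, std::size_t N>
void copy_with_conversion(const SrcData* p_src, DstData* p_dst)
{
    SrcData buffer[N];

    for (std::size_t i = 0; i < N; ++i)
        buffer[i] = p_src[i]; // load phase: no conversion yet

    for (std::size_t i = 0; i < N; ++i)
        p_dst[i] = type_convert<DstData>{}(buffer[i]); // store phase: conversion happens here
}

Under these assumptions, a call such as copy_with_conversion<float, double, 8>(src, dst) behaves like a copy whose source and destination element types differ.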
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -265,10 +265,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global,
-                                                                           p_in_block_double);
-            blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global,
-                                                                            p_wei_block_double);
+            blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body
@@ -299,10 +299,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                    p_in_global, p_in_thread_buffer);
-                blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                    p_wei_global, p_wei_thread_buffer);
+                blockwise_in_copy
+                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                        p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy
+                    .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                        p_wei_global, p_wei_thread_buffer);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -325,9 +327,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                p_in_global, p_in_thread_buffer);
-            blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
-                p_wei_global, p_wei_thread_buffer);
+            blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                p_in_global, p_in_thread_buffer);
+            blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
+                p_wei_global, p_wei_thread_buffer);

             // LDS double buffer: GEMM on current data
@@ -396,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                                 0,
                                 b_thread_data_on_global,
                                 0})
-                .template Run<Float, address_space_t::generic, address_space_t::global>(
-                    p_out_thread, p_out_global);
+                .template Run<Float, Float, address_space_t::generic, address_space_t::global>(
+                    p_out_thread, p_out_global);
         }
     }
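The call sites above now spell out both a source and a destination element type (here both Float). A self-contained sketch of what the extra parameter permits, using a hypothetical copy object and ordinary host types (float to double) purely for illustration, not the repository's real class:

#include <cstddef>

// Hypothetical stand-in for a blockwise copy whose Run<> takes both element types,
// mirroring the shape of the updated call sites.
template <std::size_t N>
struct BlockwiseCopySketch
{
    template <typename SrcData, typename DstData>
    void Run(const SrcData* p_src, DstData* p_dst) const
    {
        for (std::size_t i = 0; i < N; ++i)
            p_dst[i] = static_cast<DstData>(p_src[i]); // conversion folded into the copy
    }
};

int main()
{
    float  p_in_global[8]       = {0, 1, 2, 3, 4, 5, 6, 7};
    double p_in_block_double[8] = {};

    // Same call shape as the kernel code above, but with two distinct element types.
    BlockwiseCopySketch<8>{}.Run<float, double>(p_in_global, p_in_block_double);
    return 0;
}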
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp

@@ -120,6 +120,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
                             BlockSrcData,
                             BlockSrcAddressSpace,
                             address_space_t::generic>(p_block_src, p_thread_buffer);
+
+        // if there is type conversion, it's done during store
         RunStoreThreadBuffer<BlockSrcData,
                              BlockDstData,
                              address_space_t::generic,
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp

@@ -478,35 +478,42 @@ struct BlockwiseGenericTensorSliceCopy_v2
         return ThreadBufferDesc::GetElementSpace();
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t BlockSrcAddressSpace = address_space_t::generic,
               address_space_t ThreadBufferAddressSpace = address_space_t::generic>
-    __device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
+    __device__ void RunLoadThreadBuffer(const SrcData* p_block_src, DstData* p_thread_buffer) const
     {
-        mThreadwiseLoad.template Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
-            p_block_src, p_thread_buffer);
+        mThreadwiseLoad
+            .template Run<SrcData, DstData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
+                p_block_src, p_thread_buffer);
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t ThreadBufferAddressSpace = address_space_t::generic,
               address_space_t BlockDstAddressSpace = address_space_t::generic>
-    __device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
+    __device__ void RunStoreThreadBuffer(const SrcData* p_thread_buffer, DstData* p_block_dst) const
     {
-        mThreadwiseStore.template Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
-            p_thread_buffer, p_block_dst);
+        mThreadwiseStore
+            .template Run<SrcData, DstData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
+                p_thread_buffer, p_block_dst);
     }

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t BlockSrcAddressSpace = address_space_t::generic,
               address_space_t BlockDstAddressSpace = address_space_t::generic>
-    __device__ void Run(const TData* p_block_src, TData* p_block_dst) const
+    __device__ void Run(const SrcData* p_block_src, DstData* p_block_dst) const
     {
-        TData p_thread_buffer[GetThreadBufferSize()];
+        SrcData p_thread_buffer[GetThreadBufferSize()];

-        RunLoadThreadBuffer<TData, BlockSrcAddressSpace, address_space_t::generic>(p_block_src,
-                                                                                   p_thread_buffer);
+        RunLoadThreadBuffer<SrcData, SrcData, BlockSrcAddressSpace, address_space_t::generic>(
+            p_block_src, p_thread_buffer);

-        RunStoreThreadBuffer<TData, address_space_t::generic, BlockDstAddressSpace>(p_thread_buffer,
-                                                                                    p_block_dst);
+        // if there is type conversion, it's done during store
+        RunStoreThreadBuffer<SrcData, DstData, address_space_t::generic, BlockDstAddressSpace>(
+            p_thread_buffer, p_block_dst);
     }

     template <typename T, bool PositiveDirection>
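In the deprecated blockwise copy, the single TData parameter is split into SrcData/DstData: the thread buffer stays in SrcData, the load is type-preserving, and any conversion is deferred to the store. A minimal sketch of that split, with plain loops standing in for the threadwise load/store members (names, loops, and the missing descriptor/address-space machinery are illustrative assumptions):

#include <cstddef>

// Minimal sketch of the SrcData/DstData split shown above.
template <std::size_t BufferSize>
struct BlockwiseCopyV2Sketch
{
    template <typename SrcData, typename DstData>
    void RunLoadThreadBuffer(const SrcData* p_block_src, DstData* p_thread_buffer) const
    {
        for (std::size_t i = 0; i < BufferSize; ++i)
            p_thread_buffer[i] = static_cast<DstData>(p_block_src[i]);
    }

    template <typename SrcData, typename DstData>
    void RunStoreThreadBuffer(const SrcData* p_thread_buffer, DstData* p_block_dst) const
    {
        for (std::size_t i = 0; i < BufferSize; ++i)
            p_block_dst[i] = static_cast<DstData>(p_thread_buffer[i]);
    }

    template <typename SrcData, typename DstData>
    void Run(const SrcData* p_block_src, DstData* p_block_dst) const
    {
        SrcData p_thread_buffer[BufferSize]; // staging buffer kept in the source type

        // load is SrcData -> SrcData: no conversion yet
        RunLoadThreadBuffer<SrcData, SrcData>(p_block_src, p_thread_buffer);

        // if there is type conversion, it's done during store
        RunStoreThreadBuffer<SrcData, DstData>(p_thread_buffer, p_block_dst);
    }
};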
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp

@@ -537,19 +537,20 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
     }
     };

-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t SrcAddressSpace = address_space_t::generic,
               address_space_t DstAddressSpace = address_space_t::generic>
-    __device__ void Run(const TData* p_src, TData* p_dst) const
+    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});

-        TData p_buffer_[buffer_desc.GetElementSpace()];
-        TData* p_buffer = p_buffer_;
+        SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
+        SrcData* p_src_buffer = p_src_buffer_;

         // copy data from src into buffer
         {
-            using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
+            using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;

             constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
             constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};
@@ -573,77 +574,88 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
             constexpr auto src_normal_dim_access_lengths =
                 src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;

             ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
                 [&](auto src_merged_dim_access_id) {
                     auto src_merged_dim_data_id = src_merged_dim_access_id;
                     src_merged_dim_data_id(src_vector_access_dim) =
                         src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;

                     // offset w.r.t. merged dimension need be computed at run-time,
                     const index_t src_merged_offset =
                         (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();

                     ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
                         auto src_normal_dim_access_id) {
                         auto src_normal_dim_data_id = src_normal_dim_access_id;
                         src_normal_dim_data_id(src_vector_access_dim) =
                             src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;

                         // offset w.r.t. normal dimension is known at compile-time
                         const index_t src_normal_offset =
                             SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);

                         src_vector_t vector_data;

                         // Read vector from src.
                         // 1. Source code version can take src of all kinds of memory-space
                         // 2. Intrinsic version using buffer_load can only take
                         //    src from global-memory
                         //
                         // Commemt for loading from global-memory:
                         // When:
                         //   1) using source code, in order for compiler to emit optimal
                         //      load instruction, or
                         //   2) using buffer_load intrinsic, in order for ISA to be valid,
                         // following assumptions need to be satisfied:
                         //   1. p_src need to be block-invariant (assumption)
                         //   2. src_normal_offset must be calculatd at compile time (guaranteed by
                         //      algorithm)
                         //   3. src_merged_offset can be runtime value (no assumption imposed)
                         static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                            vector_data = __buffer_load<TData, SrcDataPerAccess>(
+                            vector_data = __buffer_load<SrcData, SrcDataPerAccess>(
                                 p_src, src_merged_offset, src_normal_offset);
#else
                             vector_data = *reinterpret_cast<const src_vector_t*>(
                                 &p_src[src_normal_offset + src_merged_offset]);
#endif
                         }).Else([&](auto) {
                             // src can be all kinds of memory-space.
                             vector_data = *reinterpret_cast<const src_vector_t*>(
                                 &p_src[src_normal_offset + src_merged_offset]);
                         });

                         // unpack vector into buffer
                         for(index_t i = 0; i < SrcDataPerAccess; ++i)
                         {
                             auto scalar_id = make_zero_array<index_t, nDim>();
                             scalar_id(src_vector_access_dim) = i;

                             const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                                 src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);

-                            p_buffer[buffer_offset] =
-                                reinterpret_cast<const TData*>(&vector_data)[i];
+                            p_src_buffer[buffer_offset] =
+                                reinterpret_cast<const SrcData*>(&vector_data)[i];
                         }
                     });
                 });
         }

+        // type conversion
+        // TODO: would compiler do a good job reusing register for buffer?
+        DstData p_dst_buffer_[buffer_desc.GetElementSpace()];
+        DstData* p_dst_buffer = p_dst_buffer_;
+
+        ford<SliceLengths>{}([&](auto idx) {
+            p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] =
+                type_convert<DstData>{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]);
+        });
+
         // copy data from buffer into dst
         {
-            using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+            using dst_vector_t = typename vector_type<SrcData, DstDataPerAccess>::MemoryType;

             constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
             constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};
@@ -659,72 +671,72 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
             constexpr auto dst_normal_dim_access_lengths =
                 dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;

             ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}(
                 [&](auto dst_merged_dim_access_id) {
                     auto dst_merged_dim_data_id = dst_merged_dim_access_id;
                     dst_merged_dim_data_id(dst_vector_access_dim) =
                         dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                     // offset w.r.t. merged dimension need be computed at run-time,
                     const index_t dst_merged_offset =
                         (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();

                     ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
                         auto dst_normal_dim_access_id) {
                         auto dst_normal_dim_data_id = dst_normal_dim_access_id;
                         dst_normal_dim_data_id(dst_vector_access_dim) =
                             dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;

                         dst_vector_t vector_data;

                         // pack vector from buffer
                         for(index_t i = 0; i < DstDataPerAccess; ++i)
                         {
                             auto scalar_id = make_zero_array<index_t, nDim>();
                             scalar_id(dst_vector_access_dim) = i;

                             const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                                 dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);

-                            reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
+                            reinterpret_cast<SrcData*>(&vector_data)[i] = p_dst_buffer[buffer_offset];
                         }

                         // offset w.r.t. normal dimension is known at compile-time
                         const index_t dst_normal_offset =
                             DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);

                         // Write vector into dst.
                         // 1. Source code version can take dst of all kinds of memory-space
                         // 2. Intrinsic version using buffer_store can only take
                         //    dst from global-memory
                         //
                         // Commemt for storing into global-memory:
                         // When:
                         //   1) using source code, in order for compiler to emit optimal
                         //      store instruction, or
                         //   2) using buffer_store, intrinsic in order ISA to be valid
                         // following assumptions need to be satisfied:
                         //   1. p_dst need to be block-invariant (assumption)
                         //   2. dst_normal_offset must be calculatd at compile time (guaranteed by
                         //      algorithm)
                         //   3. dst_merged_offset can be runtime value (no assumption imposed)
                         static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                            __buffer_store<TData, DstDataPerAccess>(
+                            __buffer_store<SrcData, DstDataPerAccess>(
                                 vector_data, p_dst, dst_merged_offset, dst_normal_offset);
#else
                             *reinterpret_cast<dst_vector_t*>(
                                 &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
#endif
                         }).Else([&](auto) {
                             // dst can be all kinds of memory-space
                             *reinterpret_cast<dst_vector_t*>(
                                 &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
                         });
                     });
                 });
         }
     }
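The threadwise copy now has three stages: gather src elements into a SrcData staging buffer, convert the whole buffer into a DstData staging buffer, then scatter that buffer into dst. A simplified host-side sketch of the same structure, where plain scalar loops replace the vectorized, descriptor-driven accesses of the real code and static_cast stands in for type_convert:

#include <cstddef>

// Simplified three-stage copy: src -> SrcData buffer -> DstData buffer -> dst.
template <typename SrcData, typename DstData, std::size_t N>
void threadwise_copy_sketch(const SrcData* p_src, DstData* p_dst)
{
    SrcData p_src_buffer[N];
    DstData p_dst_buffer[N];

    // 1) copy data from src into buffer (vectorized in the real code)
    for (std::size_t i = 0; i < N; ++i)
        p_src_buffer[i] = p_src[i];

    // 2) type conversion, element by element (the diff uses type_convert<DstData>{} here)
    for (std::size_t i = 0; i < N; ++i)
        p_dst_buffer[i] = static_cast<DstData>(p_src_buffer[i]);

    // 3) copy data from buffer into dst (vectorized in the real code)
    for (std::size_t i = 0; i < N; ++i)
        p_dst[i] = p_dst_buffer[i];
}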
driver/src/driver.cpp

@@ -295,7 +295,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 1
+#elif 0
    // 1x7 filter, 0x3 pad, 17x17 input
    constexpr index_t N = 128;
    constexpr index_t C = 128;
@@ -438,7 +438,7 @@ int main(int argc, char* argv[])
#elif 0
    device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
    device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                         in_nchw,
                                                         wei_kcyx_desc,
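The driver picks the problem size and the kernel at compile time through #if/#elif chains; this commit only flips which branches are active (the 3x3, stride-2 problem and the v4r1 kernel). A stripped-down sketch of that selection pattern, with placeholder constants rather than the driver's full configuration:

#include <cstdio>

using index_t = int;

#if 0
// placeholder branch, disabled
constexpr index_t N = 64;
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output (the branch this commit enables)
constexpr index_t N = 128;
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input (the branch this commit disables)
constexpr index_t N = 128;
#endif

int main()
{
    std::printf("selected N = %d\n", N); // only the active #elif branch is compiled
    return 0;
}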