yangql / composable_kernel-1 · Commit 3cb2a7d0

Authored Sep 25, 2019 by Chao Liu
Parent commit: 39d92e7d

Commit message: removing old implementation of tensor descriptor
Showing 3 changed files with 42 additions and 30 deletions:
- composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp (+2, -2)
- composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (+16, -13)
- composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp (+24, -15)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp (view file @ 3cb2a7d0)
...
@@ -431,9 +431,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
                     b_thread_data_on_global,
                     0})
 #if 1
-            .template Run_generic<Float, address_space_t::generic, address_space_t::global>
+            .template Run_generic<Float, Float, address_space_t::generic, address_space_t::global>
 #elif 1
-            .template Run_optimized_dst_address_calculation<Float, address_space_t::global>
+            .template Run_optimized_dst_address_calculation<Float, Float, address_space_t::global>
 #endif
             (p_out_thread, p_out_global);
     }
...
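The only change at this call site is that the copy's element type is now given twice: Run_generic and Run_optimized_dst_address_calculation take separate source and destination data types, and this kernel keeps its old behaviour by passing Float for both. A minimal host-side sketch of the idea, assuming a hypothetical copy_generic stand-in (not the library's device code):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for Run_generic: one template parameter per side,
// so the destination may use a different element type than the source.
template <typename SrcData, typename DstData>
void copy_generic(const SrcData* p_src, DstData* p_dst, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_dst[i] = static_cast<DstData>(p_src[i]); // per-element conversion
}

int main()
{
    float src[4] = {0.5f, 1.5f, 2.5f, 3.5f};
    float dst_same[4];
    double dst_wide[4];

    copy_generic<float, float>(src, dst_same, 4);  // old behaviour: same type twice
    copy_generic<float, double>(src, dst_wide, 4); // now possible: differing types

    std::printf("%f %f\n", dst_same[3], dst_wide[3]);
    return 0;
}
```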
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (view file @ 3cb2a7d0)
...
@@ -734,43 +734,46 @@ struct BlockwiseGenericTensorSliceCopy_v4
         return RegisterBufferDesc::GetElementSpace();
     }
 
-    template <typename TData, address_space_t SrcAddressSpace = address_space_t::generic>
-    __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
+    template <typename SrcData,
+              typename BufferData,
+              address_space_t SrcAddressSpace = address_space_t::generic>
+    __device__ void RunLoadRegisterBuffer(const SrcData* p_src, BufferData* p_buffer) const
     {
 #if 1
-        mThreadwiseLoad.template Run_generic<TData, SrcAddressSpace, address_space_t::generic>(
+        mThreadwiseLoad.template Run_generic<SrcData, BufferData, SrcAddressSpace, address_space_t::generic>(
             p_src, p_buffer);
 #else
-        mThreadwiseLoad.template Run_optimized_src_address_calculation<TData,
+        mThreadwiseLoad.template Run_optimized_src_address_calculation<SrcData,
+                                                                       BufferData,
                                                                        SrcAddressSpace,
                                                                        address_space_t::generic>(
             p_src, p_buffer);
 #endif
     }
 
-    template <typename TData, address_space_t DstAddressSpace = address_space_t::generic>
-    __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
+    template <typename BufferData,
+              typename DstData,
+              address_space_t DstAddressSpace = address_space_t::generic>
+    __device__ void RunStoreRegisterBuffer(const BufferData* p_buffer, DstData* p_dst) const
     {
 #if 1
-        mThreadwiseStore.template Run_generic<TData, address_space_t::generic, DstAddressSpace>(
+        mThreadwiseStore.template Run_generic<BufferData, DstData, address_space_t::generic, DstAddressSpace>(
             p_buffer, p_dst);
 #else
-        mThreadwiseStore.template Run_optimized_dst_address_calculation<TData,
+        mThreadwiseStore.template Run_optimized_dst_address_calculation<BufferData,
+                                                                        DstData,
                                                                         address_space_t::generic,
                                                                         DstAddressSpace>(
             p_buffer, p_dst);
 #endif
     }
 
-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t SrcAddressSpace = address_space_t::generic,
               address_space_t DstAddressSpace = address_space_t::generic>
-    __device__ void Run(const TData* p_src, TData* p_dst) const
+    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
     {
-        TData p_buffer[GetRegisterBufferSize()];
+        SrcData p_src_buffer[GetRegisterBufferSize()];
 
-        RunLoadRegisterBuffer<TData, SrcAddressSpace>(p_src, p_buffer);
-        RunStoreRegisterBuffer<TData, DstAddressSpace>(p_buffer, p_dst);
+        RunLoadRegisterBuffer<SrcData, SrcData, SrcAddressSpace>(p_src, p_src_buffer);
+        RunStoreRegisterBuffer<SrcData, DstData, DstAddressSpace>(p_src_buffer, p_dst);
     }
 
     template <typename T, bool PositiveDirection>
...
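After this change, BlockwiseGenericTensorSliceCopy_v4::Run stages data through a register buffer typed as SrcData: the load stage copies SrcData to SrcData, and the store stage copies SrcData to DstData. A minimal host-side sketch of that two-stage structure, using hypothetical free functions in place of RunLoadRegisterBuffer / RunStoreRegisterBuffer:

```cpp
#include <cstddef>

// Hypothetical stand-ins for the member functions; each stage names the
// element type on both of its sides.
template <typename SrcData, typename BufferData>
void load_register_buffer(const SrcData* p_src, BufferData* p_buffer, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_buffer[i] = static_cast<BufferData>(p_src[i]);
}

template <typename BufferData, typename DstData>
void store_register_buffer(const BufferData* p_buffer, DstData* p_dst, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        p_dst[i] = static_cast<DstData>(p_buffer[i]);
}

template <typename SrcData, typename DstData>
void run_copy(const SrcData* p_src, DstData* p_dst, std::size_t n)
{
    // Staging buffer keeps the source type; the real code sizes it with
    // GetRegisterBufferSize(). 64 is an arbitrary bound for this sketch.
    SrcData p_src_buffer[64];

    load_register_buffer<SrcData, SrcData>(p_src, p_src_buffer, n);  // src -> buffer
    store_register_buffer<SrcData, DstData>(p_src_buffer, p_dst, n); // buffer -> dst (converts)
}

int main()
{
    float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    double dst[8];
    run_copy(src, dst, 8);
    return 0;
}
```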
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp (view file @ 3cb2a7d0)
...
@@ -1179,13 +1179,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     // Will do padding check on src data: Read 0 if src data is in padding area.
     // Will do padding check on dst data: No write if dst data is in paddin area.
-    template <typename TData,
+    template <typename SrcData,
+              typename DstData,
               address_space_t SrcAddressSpace = address_space_t::generic,
               address_space_t DstAddressSpace = address_space_t::generic>
-    __device__ void Run_generic(const TData* p_src, TData* p_dst) const
+    __device__ void Run_generic(const SrcData* p_src, DstData* p_dst) const
     {
-        using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
-        using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+        using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
+        using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
 
         constexpr auto vector_access_dim = Number<VectorAccessDim>{};
...
@@ -1205,13 +1206,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
             long_vector_data_begin_id(vector_access_dim) =
                 long_vector_size * long_vector_access_id[vector_access_dim];
 
-            // buffer to hold a long-vector
-            TData p_long_vector[long_vector_size];
+            // buffer to hold a src long-vector
+            SrcData p_src_long_vector[long_vector_size];
 
             // zero out buffer
             for(index_t i = 0; i < long_vector_size; ++i)
             {
-                p_long_vector[i] = 0;
+                p_src_long_vector[i] = 0;
             }
 
             // load data from src to the long-vector buffer
...
@@ -1231,20 +1232,28 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 {
                     static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
-                            __buffer_load<TData, SrcDataPerAccess>(p_src, src_coord.GetOffset(), 0);
+                        *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
+                            __buffer_load<SrcData, SrcDataPerAccess>(p_src, src_coord.GetOffset(), 0);
 #else
-                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
+                        *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                             *reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
 #endif
                     }).Else([&](auto) {
                         // src can be all kinds of memory-space.
-                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
+                        *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                             *reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
                     });
                 }
             }
 
+            // SrcData to DstData conversion
+            DstData p_dst_long_vector[long_vector_size];
+
+            for(index_t i = 0; i < long_vector_size; ++i)
+            {
+                p_dst_long_vector[i] = type_convert<DstData>(p_src_long_vector[i]);
+            }
+
             // store data from the long-vector buffer to dst
             for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
             {
...
@@ -1262,19 +1271,19 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 {
                     static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
-                        __buffer_store<TData, DstDataPerAccess>(
-                            *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]),
+                        __buffer_store<DstData, DstDataPerAccess>(
+                            *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]),
                             p_dst,
                             dst_coord.GetOffset(),
                             0);
 #else
                         *reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
-                            *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
+                            *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
 #endif
                     }).Else([&](auto) {
                         // dst can be all kinds of memory-space
                         *reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
-                            *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
+                            *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
                     });
                 }
             }
...
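The substantive addition in this file is the explicit conversion pass: data is first gathered into a SrcData-typed long-vector buffer, converted element by element into a DstData-typed buffer, and only then written out vector by vector. A minimal sketch of that conversion step, with type_convert reduced to a plain cast (the library's helper may do more, e.g. rounding for half-precision types):

```cpp
#include <cstdio>

// Simplified stand-in for composable_kernel's type_convert helper.
template <typename Dst, typename Src>
Dst type_convert(Src v)
{
    return static_cast<Dst>(v);
}

int main()
{
    constexpr int long_vector_size = 4;

    float p_src_long_vector[long_vector_size] = {0.25f, 1.75f, -2.5f, 3.0f};
    double p_dst_long_vector[long_vector_size];

    // SrcData to DstData conversion, as in the new Run_generic
    for(int i = 0; i < long_vector_size; ++i)
    {
        p_dst_long_vector[i] = type_convert<double>(p_src_long_vector[i]);
    }

    std::printf("%f\n", p_dst_long_vector[2]);
    return 0;
}
```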