Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
841b1480
Commit
841b1480
authored
Apr 17, 2021
by
Chao Liu
Browse files
replacing array with vector for tensor data
parent
e4790c25
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
11 deletions
+31
-11
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
+2
-4
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+0
-4
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+5
-3
composable_kernel/include/utility/buffer.hpp
composable_kernel/include/utility/buffer.hpp
+23
-0
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+1
-0
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
View file @
841b1480
...
...
@@ -281,10 +281,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
make_tuple
(
0
,
NPerThreadSubC
)),
p_c_thread
+
c_thread_mtx
.
CalculateOffset
(
make_tuple
(
0
,
NPerThreadSubC
)));
#pragma unroll
// loop over rest of k
for
(
index_t
k
=
KPerThreadLoop
;
k
<
K
;
k
+=
KPerThreadLoop
)
{
static_for
<
KPerThreadLoop
,
K
,
KPerThreadLoop
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy
.
Run
(
p_a_block_off
+
a_block_mtx
.
CalculateOffset
(
make_tuple
(
k
,
0
)),
p_a_thread
);
...
...
@@ -324,7 +322,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
p_a_thread
,
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
make_tuple
(
0
,
NPerThreadSubC
)),
p_c_thread
+
c_thread_mtx
.
CalculateOffset
(
make_tuple
(
0
,
NPerThreadSubC
)));
}
}
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
841b1480
...
...
@@ -265,7 +265,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
index_t
b_block_data_begin
=
0
;
#if 1
if
constexpr
(
HasMainKBlockLoop
)
{
FloatAB
*
p_b_thread_even
=
p_b_thread_double
;
...
...
@@ -350,9 +349,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_thread_double
,
p_c_thread
);
}
#endif
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
...
...
@@ -385,7 +382,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_c_global
,
c_k_n_ho_wo_global_tensor_iterator_hacks
);
}
#endif
}
// pass tensor descriptor by reference
...
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
841b1480
...
...
@@ -875,7 +875,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr
index_t
buffer_offset
=
buffer_desc_
.
CalculateOffset
(
src_data_idx
+
i
*
src_scalar_step_in_vector
);
buffer_
(
Number
<
buffer_offset
>
{})
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
buffer_
.
template
AsType
<
SrcData
>()(
Number
<
buffer_offset
>
{})
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
});
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
@@ -1032,7 +1033,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr
index_t
buffer_offset
=
buffer_desc_
.
CalculateOffset
(
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
buffer_
[
Number
<
buffer_offset
>
{}];
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
buffer_
.
template
AsType
<
DstData
>()[
Number
<
buffer_offset
>
{}];
});
using
DstVectorType
=
...
...
@@ -1297,7 +1299,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static
constexpr
auto
buffer_size_
=
buffer_desc_
.
GetElementSpaceSize
();
Static
allyIndexedArray
<
SrcData
,
buffer_size_
>
buffer_
;
Static
Buffer
<
SrcData
,
buffer_size_
>
buffer_
;
SrcCoord
src_slice_origin_coord_
;
DstCoord
dst_slice_origin_coord_
;
...
...
composable_kernel/include/utility/buffer.hpp
0 → 100644
View file @
841b1480
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
#include "float_type.hpp"
namespace
ck
{
// Fixed-size buffer holding N elements of type T.
//
// StaticBuffer adds no data members of its own: it publicly derives from
// vector_type_maker<T, N>::type and therefore exposes whatever interface that
// type provides. The alias `base` names the inherited type.
template <typename T, index_t N>
struct StaticBuffer : vector_type_maker<T, N>::type
{
    // Inherited storage type selected by vector_type_maker for (T, N).
    using base = typename vector_type_maker<T, N>::type;

    // Default constructor: brace-initializes the base sub-object with an
    // empty initializer list; the constructor body itself is empty.
    __host__ __device__ constexpr StaticBuffer() : base{} {}
};
// Factory for StaticBuffer<T, N>.
//
// T must be supplied explicitly by the caller; N is deduced from the
// Number<N> argument. The argument is unnamed and never read - it is used
// only to carry N in its type.
//
// Returns a default-constructed StaticBuffer<T, N>.
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    using buffer_type = StaticBuffer<T, N>;

    return buffer_type{};
}
}
// namespace ck
#endif
composable_kernel/include/utility/common_header.hpp
View file @
841b1480
...
...
@@ -7,6 +7,7 @@
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment