gaoqiong / composable_kernel · Commits

Commit f6ec737c, authored Feb 20, 2021 by Chao Liu
Parent: a7c587ee

    changed blockwise and threadwise gemm to use new tensor descriptor

Showing 12 changed files, with 866 additions and 1843 deletions:
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform.hpp                   +0    -544
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v1.hpp                +0    -626
composable_kernel/include/kernel_algorithm/dummy_static_transform.hpp                    +0    -124
composable_kernel/include/tensor_operation/blockwise_gemm.hpp                            +361  -22
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                     +21   -34
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp  +319  -17
composable_kernel/include/tensor_operation/threadwise_gemm.hpp                           +155  -0
composable_kernel/include/utility/math.hpp                                               +10   -2
driver/include/device_dummy_dynamic_transform.hpp                                        +0    -198
driver/include/device_dummy_dynamic_transform_v1.hpp                                     +0    -140
driver/include/device_dummy_static_transform.hpp                                         +0    -97
driver/src/conv_driver.cpp                                                               +0    -39
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform.hpp  (deleted, 100644 → 0; diff collapsed)

composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v1.hpp  (deleted, 100644 → 0; diff collapsed)

composable_kernel/include/kernel_algorithm/dummy_static_transform.hpp  (deleted, 100644 → 0)
#ifndef CK_DUMMY_STATIC_TRANSFORM_HPP
#define CK_DUMMY_STATIC_TRANSFORM_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize, index_t BlockSize, typename Float, typename InGlobalDesc,
          typename WeiGlobalDesc, typename OutGlobalDesc, typename ConvStrides,
          typename ConvDilations, typename InLeftPads, typename InRightPads>
struct DummyStaticTransform
{
    __device__ void Run(Float* const __restrict__ p_in_global,
                        Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // input
        const index_t k0 = p_in_global[get_thread_local_1d_id()];
        const index_t n0 = p_in_global[get_thread_local_1d_id()];

        auto coord = typename TensorCoordinate<decltype(in_gemmk_gemmn_global_desc)>::type(k0, n0);

#pragma unroll 1
        for(index_t k = 0; k < 100; ++k)
        {
            coord += make_multi_index(8, 0);

            Float value = 1;

            transfer_data<Float,
                          1,
                          AddressSpace::Vgpr,
                          AddressSpace::Global,
                          InMemoryDataOperation::Set,
                          1,
                          1>(&value,
                             0,
                             true,
                             1,
                             p_in_global,
                             coord.GetOffset(),
                             coord.IsOffsetValidAssumingUpperIndexIsValid(),
                             in_gemmk_gemmn_global_desc.GetElementSpace());
        }
    }
};

} // namespace ck
#endif
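
The comment at the top of this deleted header fixes the implicit-GEMM mapping used throughout the commit: GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X. As a quick sanity check, here is a small standalone host-side sketch that evaluates that mapping for one made-up convolution; the sample dimensions are illustrative assumptions, not values taken from the repository:

#include <cstdio>

int main()
{
    // hypothetical NCHW input / KCYX weight, stride 1, dilation 1, no padding (illustration only)
    const int N = 128, C = 192, Hi = 71, Wi = 71;
    const int K = 256, Y = 3, X = 3;
    const int ConvStrideH = 1, ConvStrideW = 1, ConvDilationH = 1, ConvDilationW = 1;

    // output spatial size of the forward convolution
    const int Ho = (Hi - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
    const int Wo = (Wi - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;

    // GEMM view of the convolution, matching the comment in dummy_static_transform.hpp
    const long GemmM = K;                   // one output channel per GEMM row
    const long GemmN = (long)N * Ho * Wo;   // one output pixel per GEMM column
    const long GemmK = (long)C * Y * X;     // reduction over input channels and filter window

    std::printf("Ho=%d Wo=%d  GemmM=%ld GemmN=%ld GemmK=%ld\n", Ho, Wo, GemmM, GemmN, GemmK);
    return 0;
}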
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -95,28 +95,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

-#if 0
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
-                                                                      index_t n_in_c)
-    {
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t MPerLevel1Cluster =
-            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
-
-        constexpr index_t NPerLevel1Cluster =
-            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
-
-        index_t m_repeat = m_in_c / MPerThreadSubC;
-        index_t n_repeat = n_in_c / NPerThreadSubC;
-
-        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
-        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
-
-        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
-                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
-    }
-#endif
-
     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void
     Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const

@@ -352,5 +330,366 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }
 };

+// blockwise GEMM: C += transpose(A) * B
+// A and B are visable to the whole block, C is distributed among each thread
+// If following number are power of 2, index calculation shall be greatly reduced:
+//    MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
+//    MLevel1ThreadCluster, NLevel1ThreadCluster
+template <index_t BlockSize, typename BlockMatrixA, typename BlockMatrixB, typename ThreadMatrixC,
+          index_t MPerThreadSubC, index_t NPerThreadSubC, index_t KPerThreadLoop,
+          index_t MLevel0ThreadCluster, index_t NLevel0ThreadCluster,
+          index_t MLevel1ThreadCluster, index_t NLevel1ThreadCluster,
+          index_t ThreadGemmADataPerRead_M, index_t ThreadGemmBDataPerRead_N>
+struct BlockwiseGemm_km_kn_m0m1n0n1_v1
+{
+    struct MatrixIndex
+    {
+        index_t row;
+        index_t col;
+    };
+
+    index_t mMyThreadOffsetA;
+    index_t mMyThreadOffsetB;
+
+    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
+                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;
+
+        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
+
+        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
+                      "wrong! K dimension not consistent\n");
+
+        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+
+        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
+                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
+                      "wrong! Cannot evenly divide work among\n");
+
+        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
+                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
+                      "wrong! ThreadMatrixC lengths is wrong");
+
+        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
+        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
+    }
+
+    __device__ static constexpr auto GetThreadMatrixCLengths()
+    {
+        constexpr auto I1 = Number<1>{};
+
+        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+
+        constexpr index_t MRepeat =
+            M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
+        constexpr index_t NRepeat =
+            N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
+
+        static_assert(M == 128, "wrong!");
+        static_assert(MPerThreadSubC == 4, "wrong!");
+        static_assert(MRepeat == 2, "wrong!");
+        static_assert(NRepeat == 2, "wrong!");
+        static_assert(NPerThreadSubC == 4, "wrong!");
+
+        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
+    }
+
+    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
+    {
+        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
+
+        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
+        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
+        index_t level1_n_id = level1_id % NLevel1ThreadCluster;
+
+        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
+        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
+        index_t level0_n_id = level0_id % NLevel0ThreadCluster;
+
+        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
+        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
+
+        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
+                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
+    }
+
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void
+    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr index_t K = a_block_mtx[I0];
+
+        constexpr index_t MPerThread = c_thread_mtx[I0];
+        constexpr index_t NPerThread = c_thread_mtx[I1];
+
+        constexpr index_t MPerLevel1Cluster =
+            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
+        constexpr index_t NPerLevel1Cluster =
+            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
+
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
+
+        // thread A, B for GEMM
+        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
+                                                                    decltype(a_thread_mtx),
+                                                                    KPerThreadLoop,
+                                                                    MPerThreadSubC,
+                                                                    ThreadGemmADataPerRead_M>{};
+
+        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
+                                                                    decltype(b_thread_mtx),
+                                                                    KPerThreadLoop,
+                                                                    NPerThreadSubC,
+                                                                    ThreadGemmBDataPerRead_N>{};
+
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
+                                                                    decltype(b_thread_mtx),
+                                                                    decltype(c_thread_mtx)>{};
+
+#pragma unroll
+        // loop over k
+        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+        {
+#pragma unroll
+            // read A
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            {
+                a_thread_copy.Run(
+                    p_a_block +
+                        a_block_mtx.CalculateOffset(make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
+                        mMyThreadOffsetA,
+                    p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, m_repeat * MPerThreadSubC)));
+            }
+
+#pragma unroll
+            // read B
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            {
+                b_thread_copy.Run(
+                    p_b_block +
+                        b_block_mtx.CalculateOffset(make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
+                        mMyThreadOffsetB,
+                    p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, n_repeat * NPerThreadSubC)));
+            }
+
+            // C += A * B
+            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        }
+    }
+
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void
+    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr auto K          = a_block_mtx.GetLength(I0);
+        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
+        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
+
+        constexpr index_t MPerLevel1Cluster =
+            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
+        constexpr index_t NPerLevel1Cluster =
+            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
+
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
+
+        static_assert(MRepeat == 2 && NRepeat == 2,
+                      "wrong! inline asm cannot deal with this GEMM config yet");
+
+        // thread A, B
+        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
+
+        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
+
+        // thread A-sub, B-sub
+        constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
+            make_tuple(Number<MPerThread>{}, Number<1>{}));
+
+        constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
+            make_tuple(Number<NPerThread>{}, Number<1>{}));
+
+        constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
+            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
+            make_tuple(Number<NPerThread>{}, Number<1>{}));
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
+
+        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
+                                                                    decltype(a_thread_mtx),
+                                                                    KPerThreadLoop,
+                                                                    MPerThreadSubC,
+                                                                    ThreadGemmADataPerRead_M>{};
+
+        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
+                                                                    decltype(b_thread_mtx),
+                                                                    KPerThreadLoop,
+                                                                    NPerThreadSubC,
+                                                                    ThreadGemmBDataPerRead_N>{};
+
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
+                                                                    decltype(b_thread_sub_mtx),
+                                                                    decltype(c_thread_sub_mtx)>{};
+
+        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
+        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
+
+        // read A_sub_0
+        a_thread_copy.Run(p_a_block_off, p_a_thread);
+
+        // read B_sub_0
+        b_thread_copy.Run(p_b_block_off, p_b_thread);
+
+        // read B_sub_1
+        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
+                          p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+
+        // read A_sub_1
+        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
+                          p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
+
+        // C_sub_00 += transpose(A_sub_0) * B_sub_0
+        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+
+        // C_sub_01 += transpose(A_sub_0) * B_sub_1
+        threadwise_gemm.Run(p_a_thread,
+                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+
+#pragma unroll
+        // loop over rest of k
+        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
+        {
+            // read A_sub_0
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
+                              p_a_thread);
+
+            // C_sub_10 += transpose(A_sub_1) * B_sub_0
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+                                p_b_thread,
+                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+
+            // read B_sub_0
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
+                              p_b_thread);
+
+            // C_sub_11 += transpose(A_sub_1) * B_sub_1
+            threadwise_gemm.Run(
+                p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+
+            // read B_sub_1
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
+                              p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+
+            // read A_sub_1
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
+                              p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
+
+            // C_sub_00 += transpose(A_sub_0) * B_sub_0
+            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+
+            // C_sub_01 += transpose(A_sub_0) * B_sub_1
+            threadwise_gemm.Run(p_a_thread,
+                                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        }
+
+        // C_sub_10 += transpose(A_sub_1) * B_sub_0
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+                            p_b_thread,
+                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+
+        // C_sub_11 += transpose(A_sub_1) * B_sub_1
+        threadwise_gemm.Run(
+            p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+    }
+
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    {
+#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
+        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
+
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
+
+        if constexpr(MRepeat == 2 && NRepeat == 2)
+        {
+            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
+        }
+        else
+        {
+            Run_naive(p_a_block, p_b_block, p_c_thread);
+        }
+#else
+        Run_naive(p_a_block, p_b_block, p_c_thread);
+#endif
+    }
+};
+
 } // namespace ck
 #endif
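
GetBeginOfThreadMatrixC in the new BlockwiseGemm_km_kn_m0m1n0n1_v1 decomposes the flat thread id into a two-level (level0 inside level1) M x N cluster position and scales it by the per-thread sub-tile size. Below is a minimal host-side sketch of that same index arithmetic, using made-up cluster sizes (64 threads) purely for illustration; it is not code from the repository:

#include <cstdio>

int main()
{
    // hypothetical cluster configuration (assumption, not taken from the commit)
    const int MPerThreadSubC = 4, NPerThreadSubC = 4;
    const int MLevel0ThreadCluster = 4, NLevel0ThreadCluster = 4;
    const int NLevel1ThreadCluster = 2; // with MLevel1ThreadCluster = 2 -> 64 threads

    const int ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
    const int MPerLevel0Cluster      = MPerThreadSubC * MLevel0ThreadCluster;
    const int NPerLevel0Cluster      = NPerThreadSubC * NLevel0ThreadCluster;

    for(int thread_id = 0; thread_id < 64; ++thread_id)
    {
        // same decomposition as GetBeginOfThreadMatrixC
        const int level1_id   = thread_id / ThreadPerLevel0Cluster;
        const int level1_m_id = level1_id / NLevel1ThreadCluster;
        const int level1_n_id = level1_id % NLevel1ThreadCluster;

        const int level0_id   = thread_id % ThreadPerLevel0Cluster;
        const int level0_m_id = level0_id / NLevel0ThreadCluster;
        const int level0_n_id = level0_id % NLevel0ThreadCluster;

        const int row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
        const int col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;

        std::printf("thread %2d -> C sub-tile origin (%3d, %3d)\n", thread_id, row, col);
    }
    return 0;
}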
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -130,12 +130,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
         constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, MPerBlock), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), Number<max_lds_align>{});

         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, NPerBlock), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), Number<max_lds_align>{});

         // A matrix blockwise copy
         auto a_blockwise_copy =

@@ -201,18 +201,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // b_mtx[KPerBlocl, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
-        constexpr index_t a_k_m_block_mtx_stride =
-            a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
-            a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
-        constexpr index_t b_k_n_block_mtx_stride =
-            b_k_n_block_desc.CalculateOffset(make_multi_index(1, 0)) -
-            b_k_n_block_desc.CalculateOffset(make_multi_index(0, 0));
-
-        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
-            Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
-        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
-            Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
-
         // sanity check
         static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                           NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,

@@ -223,23 +211,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
+        constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));

-        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
-            BlockSize, decltype(a_k_m_block_mtx_desc), decltype(b_k_n_block_mtx_desc),
-            decltype(c_m0m1_n0n1_thread_mtx_desc), MPerThread, NPerThread, KPerThread,
-            MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
-            MPerThread, NPerThread>{};
+        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v1<
+            BlockSize, decltype(a_k_m_block_desc), decltype(b_k_n_block_desc),
+            decltype(c_m0m1_n0n1_thread_desc), MPerThread, NPerThread, KPerThread,
+            MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
+            MPerThread, NPerThread>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr index_t a_block_space_size =

@@ -252,10 +240,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];

         // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

@@ -422,7 +410,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                 AddressSpace::Global,
                 CGlobalMemoryDataOperation,
                 1,
-                true,
                 true>(c_m0_m1_n0_n1_global_desc,
                       make_multi_index(m_thread_data_on_global / M1,
                                        m_thread_data_on_global % M1,
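
The "be careful of LDS alignment" comments and the double-buffer pointer arithmetic above (p_b_block_double = p_shared_block + 2 * a_block_space_size) rely on the aligned LDS descriptor padding its row stride up to max_lds_align. The following is only my own standalone sketch of that stride and space bookkeeping under that assumption, with made-up tile sizes; it is not the make_dynamic_naive_tensor_descriptor_aligned_v2 implementation:

#include <cstdio>

// round x up to a multiple of align
constexpr int round_up(int x, int align) { return (x + align - 1) / align * align; }

int main()
{
    // hypothetical block tile and alignment, for illustration only
    const int KPerBlock = 8, MPerBlock = 128, NPerBlock = 128, max_lds_align = 4;

    // K x M (and K x N) row-major tiles whose row stride is padded to the LDS alignment
    const int a_row_stride  = round_up(MPerBlock, max_lds_align);
    const int b_row_stride  = round_up(NPerBlock, max_lds_align);
    const int a_block_space = KPerBlock * a_row_stride; // elements per A buffer
    const int b_block_space = KPerBlock * b_row_stride; // elements per B buffer

    // double-buffered layout: two A buffers first, then two B buffers
    std::printf("A buffer: %d elems, B buffer: %d elems\n", a_block_space, b_block_space);
    std::printf("first B buffer starts at element offset %d\n", 2 * a_block_space);
    return 0;
}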
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -44,12 +44,12 @@ template <typename SrcData,
           AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp,
           index_t DstScalarStrideInVector,
-          bool SrcResetCoordinateAfterRun,
           bool DstResetCoordinateAfterRun>
 struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 {
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;

     using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));

@@ -61,11 +61,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     {
     }

-    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3()
-        : ThreadwiseDynamicTensorSliceTransfer_v1r3(DstDesc{}, make_zero_multi_index<nDim>())
-    {
-    }
-
     __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
     {
         dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);

@@ -297,7 +292,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             return forward_sweep;
         }();

-        // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
+        // calculate dst data index after last iteration in Run(), if it has not being reset by
         // RunWrite()
         constexpr auto dst_data_idx = [&]() {
             Index ordered_idx;

@@ -328,7 +323,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                        const Index& dst_slice_origin_step_idx)
     {
-        // if dst coord was not reset by RunWrite(), then need to adjust the step here
+        // if dst coord was not reset by Run(), then need to adjust the step here
         const auto adjusted_step_idx =
             DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                        : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

@@ -344,6 +339,319 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     DstCoord dst_slice_origin_coord_;
 }; // namespace ck

+// this version is less likely to have scratch memory issue, due to:
+// 1. It does not keep reference to tensor descriptor
+// 2. It does not construct new tensor coordinate for this->Run()
+// Assume dst_slice_origin_idx is 0
+template <typename SrcData, typename DstData, typename SrcDesc, typename DstDesc,
+          typename SliceLengths, typename DimAccessOrder, index_t SrcVectorDim,
+          index_t SrcScalarPerVector, AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace,
+          index_t SrcScalarStrideInVector, bool SrcResetCoordinateAfterRun>
+struct ThreadwiseDynamicTensorSliceTransfer_v2
+{
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    using Index = MultiIndex<nDim>;
+
+    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
+
+    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
+
+    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
+                                                                 const Index& src_slice_origin_idx)
+        : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
+    {
+    }
+
+    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
+    {
+        src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+    }
+
+    template <typename SrcIteratorHacks>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcData* p_src,
+                        DstData* p_dst,
+                        const SrcIteratorHacks& src_iterator_hacks)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        // Comments: dst_desc is constexpr
+        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+
+        // scalar per access on each dim
+        // TODO: don't use lambda_scalar_per_access
+        constexpr auto src_scalar_per_access = generate_sequence(
+            lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto src_scalar_step_in_vector =
+            generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
+
+        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto dim_access_order = DimAccessOrder{};
+
+        constexpr auto ordered_access_lengths =
+            container_reorder_given_new2old(access_lengths, dim_access_order);
+
+        // make forward iterators
+        const auto src_forward_iterators = generate_tuple(
+            [&](auto i) {
+                Index forward_step;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
+                });
+
+                return make_dynamic_tensor_coordinate_iterator(
+                    src_desc, forward_step, src_iterator_hacks[I0][i]);
+            },
+            Number<nDim>{});
+
+        // make backward iterators
+        const auto src_backward_iterators = generate_tuple(
+            [&](auto i) {
+                Index backward_step;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
+                });
+
+                return make_dynamic_tensor_coordinate_iterator(
+                    src_desc, backward_step, src_iterator_hacks[I1][i]);
+            },
+            Number<nDim>{});
+
+        // loop over tensor and copy
+        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
+            // judge move forward or move backward
+            constexpr auto forward_sweep = [&]() {
+                StaticallyIndexedArray<bool, nDim> forward_sweep;
+
+                forward_sweep(I0) = true;
+
+                static_for<1, nDim, 1>{}([&](auto i) {
+                    index_t tmp = ordered_access_idx[I0];
+
+                    static_for<0, i, 1>{}([&](auto j) {
+                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
+                    });
+
+                    forward_sweep(i) = tmp % 2 == 0;
+                });
+
+                return forward_sweep;
+            }();
+
+            // calculate src data index
+            constexpr auto src_data_idx = [&]() {
+                Index ordered_idx;
+
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    ordered_idx(i) = forward_sweep[i]
+                                         ? ordered_access_idx[i]
+                                         : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
+                });
+
+                auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
+                                    src_scalar_per_access;
+
+                return src_data_idx;
+            }();
+
+            // copy data
+            // hardcoding for buffer_store
+            // TODO refactor transfer_data() to encapsulate this
+            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");
+
+            vector_type<SrcData, SrcScalarPerVector> src_vector;
+
+            using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
+
+            if constexpr(SrcAddressSpace == AddressSpace::Global)
+            {
+                src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(
+                    p_src, src_slice_origin_coord_.GetOffset(), true, src_desc.GetElementSpaceSize());
+
+                const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                    src_desc, src_slice_origin_coord_);
+
+                src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};
+            }
+            else
+            {
+                const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                    src_desc, src_slice_origin_coord_);
+
+                src_vector.Vector() = is_valid
+                                          ? *reinterpret_cast<const src_vector_t*>(
+                                                &p_src[src_slice_origin_coord_.GetOffset()])
+                                          : src_vector_t{0};
+            }
+
+            // this is hardcoded for dst that has compile-time tensor descriptor
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                // assume dst_slice_origin_idx is 0
+                // TODO: support non-zero dst_slice_oring_idx
+                constexpr index_t dst_offset =
+                    dst_desc.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
+
+                p_dst[Number<dst_offset>{}] = src_vector[i];
+            });
+
+            constexpr auto move_on_dim = [&]() constexpr
+            {
+                StaticallyIndexedArray<bool, nDim> move_on_dim;
+
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
+
+                    static_for<i + 1, nDim, 1>{}([&](auto j) {
+                        move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
+                    });
+                });
+
+                return move_on_dim;
+            }
+            ();
+
+            // move
+            static_for<0, nDim, 1>{}([&](auto i) {
+                if constexpr(move_on_dim[i])
+                {
+                    if constexpr(forward_sweep[i])
+                    {
+                        move_dynamic_tensor_coordinate(src_desc,
+                                                       src_slice_origin_coord_,
+                                                       src_forward_iterators[dim_access_order[i]]);
+                    }
+                    else
+                    {
+                        move_dynamic_tensor_coordinate(src_desc,
+                                                       src_slice_origin_coord_,
+                                                       src_backward_iterators[dim_access_order[i]]);
+                    }
+                }
+            });
+        });
+
+        // move src coordinate back to slice origin (or not)
+        if constexpr(SrcResetCoordinateAfterRun)
+        {
+            const auto src_reset_iterator =
+                make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
+
+            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
+        }
+    }
+
+    __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
+    {
+        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
+
+        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
+
+        constexpr auto src_iterator_hacks =
+            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
+                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
+
+        Run(src_desc, p_src, p_dst, src_iterator_hacks);
+    }
+
+    __device__ static constexpr auto GetSrcCoordinateResetStep()
+    {
+        constexpr auto I0 = Number<0>{};
+
+        // scalar per access on each dim
+        // TODO: don't use lambda_scalar_per_access
+        constexpr auto src_scalar_per_access = generate_sequence(
+            lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto dim_access_order = DimAccessOrder{};
+
+        constexpr auto ordered_access_lengths =
+            container_reorder_given_new2old(access_lengths, dim_access_order);
+
+        // judge move forward or move backward during the last iteration
+        constexpr auto forward_sweep = [&]() {
+            StaticallyIndexedArray<bool, nDim> forward_sweep;
+
+            forward_sweep(I0) = true;
+
+            static_for<1, nDim, 1>{}([&](auto i) {
+                index_t tmp = ordered_access_lengths[I0] - 1;
+
+                static_for<0, i, 1>{}([&](auto j) {
+                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
+                });
+
+                forward_sweep(i) = tmp % 2 == 0;
+            });
+
+            return forward_sweep;
+        }();
+
+        // calculate src data index after last iteration in Run(), if it has not being reset by
+        // RunWrite()
+        constexpr auto src_data_idx = [&]() {
+            Index ordered_idx;
+
+            static_for<0, nDim, 1>{}([&](auto i) {
+                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
+            });
+
+            auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
+                                src_scalar_per_access;
+
+            return src_data_idx;
+        }();
+
+        //
+        constexpr auto reset_src_data_step = [&]() {
+            Index reset_src_data_step;
+
+            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
+
+            return reset_src_data_step;
+        }();
+
+        return reset_src_data_step;
+    }
+
+    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
+                                       const Index& src_slice_origin_step_idx)
+    {
+        // if src coord was not reset by Run(), then need to adjust the step here
+        const auto adjusted_step_idx =
+            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
+                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step =
+            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
+
+        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
+    }
+
+    private:
+    SrcCoord src_slice_origin_coord_;
+}; // namespace ck
+
 // this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
 // and sometimes useless instructions
 // 1. It does not keep reference to tensor descriptor

@@ -398,12 +706,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                       "wrong!");
     }

-    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3()
-        : ThreadwiseDynamicTensorSliceTransfer_v3(
-              SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
-    {
-    }
-
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
     {
         src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);

@@ -512,7 +814,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         vector_type<SrcData, SrcScalarPerVector> src_vector;

-        using SrcVectorType = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
+        using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;

 #if 1
         src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(

@@ -521,7 +823,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             src_desc, src_slice_origin_coord_);

-        src_vector.Vector() = is_valid ? src_vector.Vector() : SrcVectorType{0};
+        src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};

         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t buffer_offset =
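
The forward_sweep logic in the new ThreadwiseDynamicTensorSliceTransfer_v2 makes each dimension reverse direction depending on the parity of the more-significant access indices, so consecutive accesses differ by a single +/-1 iterator step (a snake/boustrophedon walk). A minimal host-side sketch of that traversal order in 2D, with made-up lengths and no tensor-descriptor machinery:

#include <cstdio>

int main()
{
    // hypothetical 2D access lengths, for illustration only
    const int L0 = 3, L1 = 4;

    // dimension 0 always sweeps forward; dimension 1 reverses whenever the
    // higher-order index is odd, which is the same parity test as forward_sweep
    for(int i0 = 0; i0 < L0; ++i0)
    {
        const bool forward = (i0 % 2 == 0);
        for(int j = 0; j < L1; ++j)
        {
            const int i1 = forward ? j : (L1 - 1 - j);
            std::printf("(%d, %d)\n", i0, i1); // neighbors differ by one step in one dim
        }
    }
    return 0;
}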
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -161,5 +161,160 @@ struct ThreadwiseGemmTransANormalBNormalC
     }
 };

+template <typename Float, class Matrix>
+__device__ void threadwise_matrix_set_zero_v2(Matrix, Float* __restrict__ p_thread)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    constexpr auto M = Matrix{}.GetLength(I0);
+    constexpr auto N = Matrix{}.GetLength(I1);
+
+    static_for<0, M, 1>{}([&](auto i) {
+        static_for<0, N, 1>{}([&](auto j) {
+            constexpr auto offset = Matrix{}.CalculateOffset(make_tuple(i, j));
+
+            p_thread[offset] = Float(0);
+        });
+    });
+}
+
+template <typename SrcMatrix, typename DstMatrix, index_t NSliceRow, index_t NSliceCol,
+          index_t DataPerAccess>
+struct ThreadwiseMatrixSliceCopy_v2
+{
+    template <typename Data>
+    __device__ static void Run(const Data* p_src, Data* p_dst)
+    {
+        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
+
+        static_for<0, NSliceRow, 1>{}([&](auto i) {
+            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
+                constexpr auto src_offset = SrcMatrix{}.CalculateOffset(make_tuple(i, j));
+                constexpr auto dst_offset = DstMatrix{}.CalculateOffset(make_tuple(i, j));
+
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
+                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+            });
+        });
+    }
+};
+
+// C += transpose(A) * B
+// Element of matrix can be vectorized data
+template <typename MatrixA, typename MatrixB, typename MatrixC>
+struct ThreadwiseGemm_km_kn_mn_v1
+{
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr index_t M = MatrixC{}[I0];
+        constexpr index_t N = MatrixC{}[I1];
+        constexpr index_t K = MatrixA{}[I0];
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                static_for<0, N, 1>{}([&](auto n) {
+                    const index_t a_offset =
+                        MatrixA{}.CalculateOffset(make_tuple(k, m)); // A is transposed
+                    const index_t b_offset = MatrixB{}.CalculateOffset(make_tuple(k, n));
+                    const index_t c_offset = MatrixC{}.CalculateOffset(make_tuple(m, n));
+
+                    p_c[c_offset] +=
+                        inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
+                });
+            });
+        });
+    }
+
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t M = MatrixC{}.GetLength(I0);
+        constexpr index_t N = MatrixC{}.GetLength(I1);
+        constexpr index_t K = MatrixA{}.GetLength(I0);
+
+        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                constexpr auto a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m));
+
+                if constexpr(N == 2)
+                {
+                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
+                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
+
+                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
+                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
+
+                    amd_assembly_outer_product_1x2(p_a[a_offset],
+                                                   p_b[b_offset_0],
+                                                   p_b[b_offset_1],
+                                                   p_c[c_offset_0],
+                                                   p_c[c_offset_1]);
+                }
+                else if constexpr(N == 4)
+                {
+                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
+                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
+                    constexpr auto b_offset_2 = MatrixB{}.CalculateOffset(make_tuple(k, I2));
+                    constexpr auto b_offset_3 = MatrixB{}.CalculateOffset(make_tuple(k, I3));
+
+                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
+                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
+                    constexpr auto c_offset_2 = MatrixC{}.CalculateOffset(make_tuple(m, I2));
+                    constexpr auto c_offset_3 = MatrixC{}.CalculateOffset(make_tuple(m, I3));
+
+                    amd_assembly_outer_product_1x4(p_a[a_offset],
+                                                   p_b[b_offset_0],
+                                                   p_b[b_offset_1],
+                                                   p_b[b_offset_2],
+                                                   p_b[b_offset_3],
+                                                   p_c[c_offset_0],
+                                                   p_c[c_offset_1],
+                                                   p_c[c_offset_2],
+                                                   p_c[c_offset_3]);
+                }
+            });
+        });
+    }
+#endif
+
+    template <typename FloatA, typename FloatB, typename FloatC>
+    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
+    {
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+        constexpr bool has_amd_asm =
+            is_same<FloatC, float>{} && ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
+                                         (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
+                                         (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
+
+        if constexpr(has_amd_asm)
+        {
+            Run_amd_asm(p_a, p_b, p_c);
+        }
+        else
+        {
+            Run_source(p_a, p_b, p_c);
+        }
+#else
+        Run_source(p_a, p_b, p_c);
+#endif
+    }
+};
+
 } // namespace ck
 #endif
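
ThreadwiseGemm_km_kn_mn_v1::Run_source computes C += transpose(A) * B, with A stored as K x M (the transposed operand) and B as K x N. Here is a plain host-side reference of the same contraction; the tile sizes are made up for illustration and the snippet is not taken from the repository:

#include <cstdio>
#include <vector>

int main()
{
    // hypothetical per-thread tile sizes, for illustration only
    const int K = 2, M = 4, N = 4;

    std::vector<float> a(K * M, 1.0f); // A stored [K][M]
    std::vector<float> b(K * N, 2.0f); // B stored [K][N]
    std::vector<float> c(M * N, 0.0f); // C stored [M][N]

    // C(m, n) += sum_k A(k, m) * B(k, n)
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];

    std::printf("c[0] = %f (expected %f)\n", c[0], float(K) * 1.0f * 2.0f);
    return 0;
}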
composable_kernel/include/utility/math.hpp

@@ -114,8 +114,8 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
 }

 // greatest common divisor, aka highest common factor
-template <typename X, typename Y>
-__host__ __device__ constexpr auto gcd(X x, Y y)
+template <typename T>
+__host__ __device__ constexpr T gcd(T x, T y)
 {
     if(x == y || x == 0)
     {

@@ -135,6 +135,14 @@ __host__ __device__ constexpr auto gcd(X x, Y y)
     }
 }

+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
+{
+    constexpr auto r = gcd(X, Y);
+
+    return Number<r>{};
+}
+
 template <typename X, typename... Ys>
 __host__ __device__ constexpr auto gcd(X x, Ys... ys)
 {
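
The new overload lifts gcd from runtime integers to compile-time Number<> arguments so the result stays a compile-time constant. A standalone sketch of the same idea, using std::integral_constant as a stand-in for Number<> (my own illustration, not the library code):

#include <cstdio>
#include <type_traits>

// plain recursive gcd, usable in constant expressions
constexpr int gcd(int x, int y) { return y == 0 ? x : gcd(y, x % y); }

// overload on compile-time integral constants, returning a compile-time constant,
// analogous to the Number<X>/Number<Y> overload added in math.hpp
template <int X, int Y>
constexpr auto gcd(std::integral_constant<int, X>, std::integral_constant<int, Y>)
{
    return std::integral_constant<int, gcd(X, Y)>{};
}

int main()
{
    constexpr auto r = gcd(std::integral_constant<int, 12>{}, std::integral_constant<int, 18>{});
    static_assert(r.value == 6, "gcd(12, 18) should be 6");
    std::printf("gcd(12, 18) = %d\n", r.value);
    return 0;
}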
driver/include/device_dummy_dynamic_transform.hpp  (deleted, 100644 → 0)

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_dynamic_transform.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc, class ConvStrides,
          class ConvDilations, class InLeftPads, class InRightPads>
void device_dummy_dynamic_transform(InDesc, const Tensor<T>& in_nchw, WeiDesc,
                                    const Tensor<T>& wei_kcyx, OutDesc, Tensor<T>& out_nkhw,
                                    ConvStrides, ConvDilations, InLeftPads, InRightPads,
                                    ck::index_t nrepeat)
{
    using namespace ck;

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    const auto in_nchw_desc = make_dynamic_naive_tensor_descriptor<4>(
        to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));
    const auto wei_kcyx_desc = make_dynamic_naive_tensor_descriptor<4>(
        to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));
    const auto out_nkhw_desc = make_dynamic_naive_tensor_descriptor<4>(
        to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));

    const auto conv_strides   = to_multi_index(ConvStrides{});
    const auto conv_dilations = to_multi_index(ConvDilations{});
    const auto in_left_pads   = to_multi_index(InLeftPads{});
    const auto in_right_pads  = to_multi_index(InRightPads{});

    const auto tensor_descs = map_convolution_into_gemm_fwd_v4r4(wei_kcyx_desc,
                                                                 in_nchw_desc,
                                                                 out_nkhw_desc,
                                                                 conv_strides,
                                                                 conv_dilations,
                                                                 in_left_pads,
                                                                 in_right_pads);

    const auto in_gemmk_gemmn_gemmkpack_global_desc = tensor_descs.At(Number<0>{});

    // test on cpu
    {
        auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
            in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));

        const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1 =
            make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
                                                    make_multi_index(0, 0, 1));

        print_array_v2("do_tansforms 0 0 1: ",
                       in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1.do_transforms_);

        for(index_t iter = 0; iter < 10; ++iter)
        {
            printf("iter %d\n", iter);
            print_array_v2("idx: ", in_gemmk_gemmn_gemmkpack_coord.GetIndex());
            print_array_v2("hidden idx: ", in_gemmk_gemmn_gemmkpack_coord.GetHiddenIndex());
            printf("offset: %d\n", in_gemmk_gemmn_gemmkpack_coord.GetOffset());
            printf("\n");

            move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
                                           in_gemmk_gemmn_gemmkpack_coord,
                                           in_gemmk_gemmn_gemmkpack_coord_iterator_0_0_1);
        }
    }

    {
        auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
            in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));

        const auto in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0 =
            make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
                                                    make_multi_index(0, 1, 0));

        print_array_v2("do_tansforms 0 1 0: ",
                       in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0.do_transforms_);

        for(index_t iter = 0; iter < 10; ++iter)
        {
            printf("iter %d\n", iter);
            print_array_v2("idx: ", in_gemmk_gemmn_gemmkpack_coord.GetIndex());
            print_array_v2("hidden idx: ", in_gemmk_gemmn_gemmkpack_coord.GetHiddenIndex());
            printf("offset: %d\n", in_gemmk_gemmn_gemmkpack_coord.GetOffset());
            printf("\n");

            move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
                                           in_gemmk_gemmn_gemmkpack_coord,
                                           in_gemmk_gemmn_gemmkpack_coord_iterator_0_1_0);
        }
    }

    {
        auto in_gemmk_gemmn_gemmkpack_coord = make_dynamic_tensor_coordinate(
            in_gemmk_gemmn_gemmkpack_global_desc, make_multi_index(0, 0, 0));

        const auto in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0 =
            make_dynamic_tensor_coordinate_iterator(in_gemmk_gemmn_gemmkpack_global_desc,
                                                    make_multi_index(1, 0, 0));

        print_array_v2("do_tansforms 1 0 0: ",
                       in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0.do_transforms_);

        for(index_t iter = 0; iter < 10; ++iter)
        {
            printf("iter %d\n", iter);
            print_array_v2("idx: ", in_gemmk_gemmn_gemmkpack_coord.GetIndex());
            print_array_v2("hidden idx: ", in_gemmk_gemmn_gemmkpack_coord.GetHiddenIndex());
            printf("offset: %d\n", in_gemmk_gemmn_gemmkpack_coord.GetOffset());
            printf("\n");

            move_dynamic_tensor_coordinate(in_gemmk_gemmn_gemmkpack_global_desc,
                                           in_gemmk_gemmn_gemmkpack_coord,
                                           in_gemmk_gemmn_gemmkpack_coord_iterator_1_0_0);
        }
    }

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t BlockSize = 256;
    constexpr index_t GridSize  = 1;

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
#if 0
            launch_kernel(run_gridwise_operation<DummyDynamicTransform_1<BlockSize>,
                                                 index_t* const,
                                                 float* const,
                                                 float* const,
                                                 const decltype(wei_kcyx_desc),
                                                 const decltype(in_nchw_desc),
                                                 const decltype(out_nkhw_desc),
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
                          wei_kcyx_desc,
                          in_nchw_desc,
                          out_nkhw_desc,
                          conv_strides,
                          conv_dilations,
                          in_left_pads,
                          in_right_pads);
#else
            launch_kernel(
                run_gridwise_operation<DummyDynamicTransform_fwd_v4r4<BlockSize>,
                                       index_t* const,
                                       float* const,
                                       float* const,
                                       const decltype(in_gemmk_gemmn_gemmkpack_global_desc)>,
                dim3(GridSize),
                dim3(BlockSize),
                0,
                0,
                static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
                static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
                in_gemmk_gemmn_gemmkpack_global_desc);
#endif
        }
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
driver/include/device_dummy_dynamic_transform_v1.hpp
deleted
100644 → 0
View file @
a7c587ee
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_dynamic_transform_v1.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
InLeftPads
,
class
InRightPads
>
void
device_dummy_dynamic_transform_v1
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
const
auto
in_nchw_desc
=
make_dynamic_naive_tensor_descriptor_v1
(
to_multi_index
(
InDesc
::
GetLengths
()),
to_multi_index
(
InDesc
::
GetStrides
()));
const
auto
wei_kcyx_desc
=
make_dynamic_naive_tensor_descriptor_v1
(
to_multi_index
(
WeiDesc
::
GetLengths
()),
to_multi_index
(
WeiDesc
::
GetStrides
()));
const
auto
out_nkhw_desc
=
make_dynamic_naive_tensor_descriptor_v1
(
to_multi_index
(
OutDesc
::
GetLengths
()),
to_multi_index
(
OutDesc
::
GetStrides
()));
const
auto
conv_strides
=
to_multi_index
(
ConvStrides
{});
const
auto
conv_dilations
=
to_multi_index
(
ConvDilations
{});
const
auto
in_left_pads
=
to_multi_index
(
InLeftPads
{});
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
    {
        const auto tensor_descs = map_convolution_into_gemm_v1(wei_kcyx_desc,
                                                               in_nchw_desc,
                                                               out_nkhw_desc,
                                                               conv_strides,
                                                               conv_dilations,
                                                               in_left_pads,
                                                               in_right_pads);

        const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{});

        auto in_gemmk_gemmn_coord =
            make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, make_multi_index(0, 0));

        for(index_t iter = 0; iter < 10; ++iter)
        {
            constexpr auto gemmk1_gemmn0 = make_multi_index(1, 0);

            printf("iter %d\n", iter);

            print_array("idx0: ", in_gemmk_gemmn_coord.GetIndex());
            print_array("idx1: ", in_gemmk_gemmn_coord.GetLowerCoordinate().GetIndex());
            print_array("idx2: ",
                        in_gemmk_gemmn_coord.GetLowerCoordinate().GetLowerCoordinate().GetIndex());
            print_array("idx3: ",
                        in_gemmk_gemmn_coord.GetLowerCoordinate()
                            .GetLowerCoordinate()
                            .GetLowerCoordinate()
                            .GetIndex());
            print_array("idx4: ",
                        in_gemmk_gemmn_coord.GetLowerCoordinate()
                            .GetLowerCoordinate()
                            .GetLowerCoordinate()
                            .GetLowerCoordinate()
                            .GetIndex());
            printf("offset: %d\n", in_gemmk_gemmn_coord.GetOffset());
            printf("\n");

            in_gemmk_gemmn_coord += gemmk1_gemmn0;
        }
    }
    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t BlockSize = 256;
    constexpr index_t GridSize  = 1;

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    using dummy_transform = DummyDynamicTransform_v1<BlockSize>;

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<dummy_transform,
                                                 index_t* const,
                                                 float* const,
                                                 float* const,
                                                 const DynamicNativeTensorDescriptor_v1<4>,
                                                 const DynamicNativeTensorDescriptor_v1<4>,
                                                 const DynamicNativeTensorDescriptor_v1<4>,
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>,
                                                 const MultiIndex<2>>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
                          wei_kcyx_desc,
                          in_nchw_desc,
                          out_nkhw_desc,
                          conv_strides,
                          conv_dilations,
                          in_left_pads,
                          in_right_pads);
        }
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
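The coordinate-probing loop near the top of this function only exercises the transformed descriptor: each += gemmk1_gemmn0 step moves the GemmK index by one and lets the coordinate machinery propagate the change down the chain of lower coordinates until it reaches a linear offset. A self-contained host-side sketch of the same idea for a plain naive 2-D row-major layout follows; the lengths and strides are made-up illustration values, and none of this is the ck API.

#include <cstdio>

int main()
{
    // Naive 2-D descriptor stand-in: lengths {8, 16}, row-major strides {16, 1} (assumed).
    const int strides[2] = {16, 1};

    int idx[2] = {0, 0}; // (GemmK, GemmN)
    int offset = 0;

    for(int iter = 0; iter < 10; ++iter)
    {
        std::printf("iter %d: idx (%d, %d), offset %d\n", iter, idx[0], idx[1], offset);

        idx[0] += 1;          // step by (1, 0), like gemmk1_gemmn0 above
        offset += strides[0]; // the offset advances by the stride of the stepped dimension
    }

    return 0;
}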
driver/include/device_dummy_static_transform.hpp
deleted
100644 → 0
View file @
a7c587ee
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_static_transform.hpp"
template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_dummy_static_transform(InDesc,
                                   const Tensor<T>& in_nchw,
                                   WeiDesc,
                                   const Tensor<T>& wei_kcyx,
                                   OutDesc,
                                   Tensor<T>& out_nkhw,
                                   ConvStrides,
                                   ConvDilations,
                                   InLeftPads,
                                   InRightPads,
                                   ck::index_t nrepeat)
{
    using namespace ck;

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc =
        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
    constexpr auto wei_kcyx_desc =
        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
    constexpr auto out_nkhw_desc =
        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t BlockSize = 256;
    constexpr index_t GridSize  = 1;

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    using dummy_transform = DummyStaticTransform<GridSize,
                                                 BlockSize,
                                                 float,
                                                 decltype(in_nchw_desc),
                                                 decltype(wei_kcyx_desc),
                                                 decltype(out_nkhw_desc),
                                                 ConvStrides,
                                                 ConvDilations,
                                                 InLeftPads,
                                                 InRightPads>;

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<dummy_transform,
                                                 float* const __restrict__,
                                                 float* const __restrict__,
                                                 float* const __restrict__>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
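The static and dynamic drivers in this commit differ mainly in when the tensor shapes are known: device_dummy_static_transform bakes lengths and strides into the descriptor type (so N, K, Ho, Wo above are constexpr), while the dynamic variants carry them as runtime multi-indices. The following simplified, self-contained illustration of that trade-off uses invented types (StaticLengths, DynamicLengths), not the ck descriptors.

#include <array>
#include <cstdio>

// Lengths carried as template parameters: available at compile time.
template <int... Ls>
struct StaticLengths
{
    static constexpr int At(int i)
    {
        constexpr int lengths[] = {Ls...};
        return lengths[i];
    }
};

// Lengths carried as runtime values: only available at run time.
struct DynamicLengths
{
    std::array<int, 4> lengths;
    int At(int i) const { return lengths[i]; }
};

int main()
{
    constexpr int N = StaticLengths<128, 256, 14, 14>::At(0); // usable in constexpr context
    DynamicLengths dyn{{128, 256, 14, 14}};                   // known only at run time

    std::printf("static N = %d, dynamic N = %d\n", N, dyn.At(0));
    return 0;
}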
driver/src/conv_driver.cpp
View file @
f6ec737c
...
@@ -14,9 +14,6 @@
 #include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
-#include "device_dummy_static_transform.hpp"
-#include "device_dummy_dynamic_transform_v1.hpp"
-#include "device_dummy_dynamic_transform.hpp"
 int main(int argc, char* argv[])
 {
...
@@ -616,42 +613,6 @@ int main(int argc, char* argv[])
                                                LeftPads{},
                                                RightPads{},
                                                nrepeat);
-#elif 0
-    device_dummy_static_transform(in_nchw_desc,
-                                  in_nchw,
-                                  wei_kcyx_desc,
-                                  wei_kcyx,
-                                  out_nkhw_desc,
-                                  out_nkhw_device,
-                                  ConvStrides{},
-                                  ConvDilations{},
-                                  LeftPads{},
-                                  RightPads{},
-                                  nrepeat);
-#elif 0
-    device_dummy_dynamic_transform_v1(in_nchw_desc,
-                                      in_nchw,
-                                      wei_kcyx_desc,
-                                      wei_kcyx,
-                                      out_nkhw_desc,
-                                      out_nkhw_device,
-                                      ConvStrides{},
-                                      ConvDilations{},
-                                      LeftPads{},
-                                      RightPads{},
-                                      nrepeat);
-#elif 1
-    device_dummy_dynamic_transform(in_nchw_desc,
-                                   in_nchw,
-                                   wei_kcyx_desc,
-                                   wei_kcyx,
-                                   out_nkhw_desc,
-                                   out_nkhw_device,
-                                   ConvStrides{},
-                                   ConvDilations{},
-                                   LeftPads{},
-                                   RightPads{},
-                                   nrepeat);
 #endif
    if(do_verification)
...