gaoqiong / composable_kernel · commit 1b15b21a

update static_tensor for dealing with invalid element

Authored Dec 13, 2021 by Chao Liu
Parent: 2fd5e6ae

Showing 9 changed files with 240 additions and 256 deletions (+240 −256)
  composable_kernel/include/tensor_description/static_tensor.hpp                        +20  −15
  composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp        +10  −0
  composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp              +144 −77
  composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp  +6   −1
  composable_kernel/include/utility/common_header.hpp                                   +1   −1
  composable_kernel/include/utility/config.hpp                                          +5   −1
  composable_kernel/include/utility/utility.hpp                                         +4   −0
  device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp      +50  −10
  device_operation/include/element_wise_operation.hpp (deleted)                         +0   −151
composable_kernel/include/tensor_description/static_tensor.hpp

 #ifndef CK_STATIC_TENSOR_HPP
 #define CK_STATIC_TENSOR_HPP

-#include "ignore.hpp"
-
 namespace ck {

 // StaticTensor for Scalar
@@ -17,10 +15,10 @@ struct StaticTensor
     static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
     static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

-    __host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}

     __host__ __device__ constexpr StaticTensor(T invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
     {
     }
@@ -44,11 +42,11 @@ struct StaticTensor
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
-            return T{0};
+            return zero_scalar_value_;
         }
         else
         {
-            return invalid_element_value_;
+            return invalid_element_scalar_value_;
         }
     }
@@ -71,12 +69,14 @@ struct StaticTensor
         }
         else
         {
-            return ignore;
+            return ignored_element_scalar_;
         }
     }

     StaticBuffer<AddressSpace, T, element_space_size_, true> data_;

-    T invalid_element_value_ = T{0};
+    static constexpr T zero_scalar_value_ = T{0};
+
+    const T invalid_element_scalar_value_;
+
+    T ignored_element_scalar_;
 };

 // StaticTensor for vector
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
     using V = vector_type<S, ScalarPerVector>;

-    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
+    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
+        : invalid_element_scalar_value_{0}
+    {
+    }

     __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
-        : invalid_element_value_{invalid_element_value}
+        : invalid_element_scalar_value_{invalid_element_value}
     {
     }
@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
-            return S{0};
+            return zero_scalar_value_;
         }
         else
         {
-            return invalid_element_value_;
+            return invalid_element_scalar_value_;
         }
     }
@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
         }
         else
         {
-            return ignore;
+            return ignored_element_scalar_;
         }
     }
@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
         else
         {
             // TODO: is this right way to initialize a vector?
-            return X{invalid_element_value_};
+            return X{invalid_element_scalar_value_};
         }
     }
@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
     }

     StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;

-    S invalid_element_value_ = S{0};
+    static constexpr S zero_scalar_value_ = S{0};
+    const S invalid_element_scalar_value_ = S{0};
+    S ignored_element_scalar_;
 };

 template <AddressSpaceEnum_t AddressSpace,
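Taken together, these hunks replace the single invalid_element_value_ member with three: a static zero_scalar_value_ for the numerical-zero mode, a const user-supplied invalid_element_scalar_value_, and a mutable ignored_element_scalar_ that out-of-range writes land in (previously such writes got the global read-only `ignore` object). A minimal standalone sketch of the pattern follows; Tensor1d, kSize, and the bounds check are hypothetical stand-ins for illustration, not CK code:

    // Minimal sketch of the new invalid-element handling, assuming a 1-D
    // tensor; Tensor1d, kSize and the bounds check are hypothetical, not CK.
    template <typename T, bool InvalidElementUseNumericalZeroValue, int kSize = 8>
    struct Tensor1d
    {
        constexpr Tensor1d() : invalid_element_scalar_value_{0} {}
        constexpr explicit Tensor1d(T v) : invalid_element_scalar_value_{v} {}

        // reads of an invalid element return zero or the user-supplied value
        constexpr T Get(int i) const
        {
            if(i >= 0 && i < kSize)
            {
                return data_[i];
            }
            else if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return zero_scalar_value_;
            }
            else
            {
                return invalid_element_scalar_value_;
            }
        }

        // writes to an invalid element land in a dummy scalar, so a mutable
        // reference can always be returned (this is what replaces `ignore`)
        constexpr T& At(int i)
        {
            return (i >= 0 && i < kSize) ? data_[i] : ignored_element_scalar_;
        }

        T data_[kSize]{};
        static constexpr T zero_scalar_value_ = T{0};
        const T invalid_element_scalar_value_;
        T ignored_element_scalar_{};
    };

Under this scheme reads stay constexpr-friendly, while writes to coordinates outside the valid window (e.g. padding) are silently absorbed instead of going through a shared read-only object.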
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp

@@ -114,6 +114,16 @@ struct BlockwiseTensorSliceTransfer_v4
         }
     }

+    template <typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
+    {
+        RunRead(src_desc, src_buf);
+        RunWrite(dst_desc, dst_buf);
+    }
+
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
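The added Run overload is pure convenience: it fuses the existing read and write phases into one call. A self-contained sketch of its shape (Desc, Buffer, and Transfer are stand-ins for CK's descriptor, buffer, and BlockwiseTensorSliceTransfer_v4 types; bodies are illustrative only):

    // Sketch of the added convenience wrapper; Desc/Buffer stand in for
    // CK's descriptor and buffer types, and the bodies are illustrative.
    struct Desc {};
    struct Buffer {};

    struct Transfer
    {
        void RunRead(const Desc&, const Buffer&) { /* source -> registers */ }
        void RunWrite(const Desc&, Buffer&) { /* registers -> destination */ }

        // new in this commit: one call that issues both phases back to back
        void Run(const Desc& src_desc, const Buffer& src_buf,
                 const Desc& dst_desc, Buffer& dst_buf)
        {
            RunRead(src_desc, src_buf);
            RunWrite(dst_desc, dst_buf);
        }
    };

    int main()
    {
        Transfer copy;
        Desc src_desc, dst_desc;
        Buffer src_buf, dst_buf;

        // previously: copy.RunRead(src_desc, src_buf);
        //             copy.RunWrite(dst_desc, dst_buf);
        copy.Run(src_desc, src_buf, dst_desc, dst_buf);
    }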
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp

@@ -10,6 +10,8 @@
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_tensor_slice_set.hpp"

+#define DEBUG_USE_C_SHUFFLE 0
+
 namespace ck {

 template <typename GridwiseGemm,
@@ -17,7 +19,11 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
+#if !DEBUG_USE_C_SHUFFLE
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+#else
+          typename CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
+#endif
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
@@ -33,24 +39,30 @@ __global__ void
             FloatC* __restrict__ p_c_grid,
             const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
             const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+            const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
+                c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-    constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-    __shared__ FloatAB p_shared_block[shared_block_size];
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_shared_block,
+                                                  p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+                                                  c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
@@ -220,6 +232,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return has_main_k0_block_loop;
     }

+#if !DEBUG_USE_C_SHUFFLE
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
     {
@@ -269,6 +282,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
     }
+#else
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+        const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
+        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
+
+        const index_t MBlock = M / (MWave * MPerXdl * MRepeat);
+        const index_t NBlock = N / (NWave * NPerXdl * NRepeat);
+
+        const auto c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
+            transform_tensor_descriptor(
+                c_grid_desc_m_n,
+                make_tuple(make_unmerge_transform(
+                               make_tuple(MBlock, Number<MRepeat>{}, Number<MWave * MPerXdl>{})),
+                           make_unmerge_transform(
+                               make_tuple(NBlock, Number<NRepeat>{}, Number<NWave * NPerXdl>{}))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl;
+    }
+#endif

     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
@@ -305,20 +345,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

+#if !DEBUG_USE_C_SHUFFLE
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
-    using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
+#else
+    using CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl =
+        remove_cvref_t<decltype(
+            MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+                CGridDesc_M_N{}))>;
+#endif
+
+    using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

     template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               FloatAB* __restrict__ p_shared_block,
+                               void* __restrict__ p_shared,
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+#if !DEBUG_USE_C_SHUFFLE
                                const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+#else
+                               const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl&
+                                   c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+#endif
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
@@ -328,8 +379,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+#if !DEBUG_USE_C_SHUFFLE
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+#else
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c_grid,
+            c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+                .GetElementSpaceSize());
+#endif

         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
@@ -459,8 +518,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

-        FloatAB* p_a_block = p_shared_block;
-        FloatAB* p_b_block = p_shared_block + a_block_space_size;
+#if !DEBUG_USE_C_SHUFFLE
+        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared);
+        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared) + a_block_space_size;
+
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
+#else
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            static_cast<FloatAB*>(p_shared) + a_block_space_size,
+            b_block_desc_k0_n_k1.GetElementSpaceSize());
+#endif

         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
@@ -474,11 +548,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
         constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
-
         // preload data into LDS
         {
             a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
@@ -530,7 +599,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }

-#if 1
+#if !DEBUG_USE_C_SHUFFLE
         // output: register to global memory
         {
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
@@ -620,7 +689,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
             constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;

-            // hacky
+            // TODO: hacky
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -636,29 +705,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
             constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

-            constexpr auto c_block_desc_mwavemperxdl_nwavenperxdl =
-                make_naive_tensor_descriptor_packed(
-                    Number<MPerBlock_CCopy>{}, Number<NPerBlock_CCopy>{});
-
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-                c_block_desc_mwavemperxdl_nwavenperxdl,
-                make_tuple(make_unmerge_transform(make_tuple(I1, Number<MWave>{}, M2, M3, M4)),
-                           make_unmerge_transform(make_tuple(I1, Number<NWave>{}, N2))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
-
             // calculate origin of thread output tensor on global memory
             //     blockwise GEMM c matrix starting index
             const auto c_thread_mtx_on_block =
                 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

-            const index_t n_thread_data_on_grid =
-                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-
             const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                 make_single_stage_tensor_adaptor(
                     make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
@@ -679,13 +733,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                     make_multi_index(n_thread_data_on_block));

             // VGPR to LDS
-            auto c_thread_copy_vgpr2lds =
+            auto c_thread_copy_vgpr_to_lds =
                 ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                    FloatAcc,
                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                    ck::tensor_operation::element_wise::PassThrough,
-                                                   Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
+                                                   Sequence<I1, I1, I1, I1, M2, M3, M4, N2>,
                                                    Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                    7,
                                                    1,
@@ -695,37 +749,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                     make_multi_index(0,
                                      0,
-                                     m_thread_data_on_grid_idx[I1],
-                                     n_thread_data_on_grid_idx[I1],
-                                     m_thread_data_on_grid_idx[I2],
-                                     m_thread_data_on_grid_idx[I3],
-                                     m_thread_data_on_grid_idx[I4],
-                                     n_thread_data_on_grid_idx[I2]),
+                                     m_thread_data_on_block_idx[I1],
+                                     n_thread_data_on_block_idx[I1],
+                                     m_thread_data_on_block_idx[I2],
+                                     m_thread_data_on_block_idx[I3],
+                                     m_thread_data_on_block_idx[I4],
+                                     n_thread_data_on_block_idx[I2]),
                     ck::tensor_operation::element_wise::PassThrough{}};

-            // hardcoded
+            // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
             constexpr index_t MThread_CCopy = 16;
             constexpr index_t NThread_CCopy = 16;

             constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
             constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;

-            constexpr auto c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl =
+            constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
                 make_naive_tensor_descriptor_packed(
-                    make_tuple(I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{});
+                    make_tuple(I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
+
+            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+                static_cast<FloatAcc*>(p_shared),
+                c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+                    .GetElementSpaceSize());

-            auto c_block_copy = BlockwiseTensorSliceTransfer_v4<
+            auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
                 BlockSize,                                       // index_t BlockSize,
                 ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
                 CGlobalMemoryDataOperation,                      // DstInMemOp,
                 Sequence<1, 1, MPerBlock_CCopy, 1, 1, NPerBlock_CCopy>,   // BlockSliceLengths,
                 Sequence<1, 1, MPerThread_CCopy, 1, 1, NPerThread_CCopy>, // ThreadSliceLengths,
-                Sequence<1, 1, MPerThread, 1, 1, NPerThread>,       // typename ThreadClusterLengths,
+                Sequence<1, 1, MThread_CCopy, 1, 1, NThread_CCopy>, // ThreadClusterLengths,
                 Sequence<0, 1, 2, 3, 4, 5>,                      // typename ThreadClusterArrangeOrder,
                 FloatAcc,                                        // typename SrcData,
                 FloatC,                                          // typename DstData,
-                decltype(c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl),
-                decltype(c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl),
+                decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
+                decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
                 Sequence<0, 1, 2, 3, 4, 5>,                      // typename SrcDimAccessOrder,
                 Sequence<0, 1, 2, 3, 4, 5>,                      // typename DstDimAccessOrder,
                 5,                                               // index_t SrcVectorDim,
@@ -736,12 +795,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 1,     // index_t DstScalarStrideInVector,
                 true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                 false> // bool ThreadTransferDstResetCoordinateAfterRun>
-                {c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                {c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                  make_multi_index(0, 0, 0, 0, 0, 0),
-                 c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                 c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
-                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0)};
+                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
+                 ck::tensor_operation::element_wise::PassThrough{}};

             constexpr auto mrepeat_forward_step = make_multi_index(0, 1, 0, 0, 0, 0);
             constexpr auto nrepeat_forward_step = make_multi_index(0, 0, 0, 0, 1, 0);
@@ -750,37 +808,46 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             // make sure all ds_read from GEMM is completed
             block_sync_lds();

-            static_for<0, MRepeat, 1>{}([&](auto mrepeat) {
-                static_for<0, NRepeat, 1>{}([&](auto nrepeat) {
-                    // VGPR to LDS
-                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                                  c_thread_buf,
-                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  c_block_buf);
+            static_for<0, MRepeat, 1>{}([&](auto mrepeat_iter) {
+                static_for<0, NRepeat, 1>{}([&](auto nrepeat_iter) {
+                    constexpr auto mrepeat = mrepeat_iter;
+
+                    constexpr bool nrepeat_forward_sweep = (mrepeat % 2 == 0);
+
+                    constexpr index_t nrepeat_value =
+                        nrepeat_forward_sweep ? nrepeat_iter : (NRepeat - nrepeat_iter - 1);
+
+                    constexpr auto nrepeat = Number<nrepeat_value>{};
+
+                    // VGPR to LDS
+                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  make_tuple(mrepeat, nrepeat, I0, I0, I0, I0, I0, I0),
+                                                  c_thread_buf,
+                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c_block_buf);

+                    // make sure ds_write from c_thread_copy_vgpr_to_lds is completed
                     block_sync_lds();

                     // LDS to global
                     c_block_copy_lds_to_global.Run(
-                        c_block_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                        c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                         c_block_buf,
-                        c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
-                        c_global_buf);
+                        c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+                        c_grid_buf);

-                    constexpr bool nrepeat_forward_sweep = mrepeat % 2 == 0;
-
                     // move on nrepeat dimension
-                    if constexpr(nrepeat_forward_sweep && nrepeat < NRepeat - 1)
+                    if constexpr(nrepeat_forward_sweep && (nrepeat < NRepeat - 1))
                     {
-                        c_block_copy.MoveDstSliceWindow(
-                            c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                        c_block_copy_lds_to_global.MoveDstSliceWindow(
+                            c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                             nrepeat_forward_step);
                     }
-                    else if constexpr((!nrepeat_forward_sweep) & nrepeat > 1)
+                    else if constexpr((!nrepeat_forward_sweep) && (nrepeat > 1))
                     {
-                        c_block_copy.MoveDstSliceWindow(
-                            c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                        c_block_copy_lds_to_global.MoveDstSliceWindow(
+                            c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                             nrepeat_backward_step);
                     }
                 });
@@ -789,7 +856,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 if constexpr(mrepeat < MRepeat - 1)
                 {
                     c_block_copy_lds_to_global.MoveDstSliceWindow(
-                        c_global_desc_mblock_mrepeat_mwaveMPerXdl_nblock_nrepeat_nwaveNPerXdl,
+                        c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                         mrepeat_forward_step);
                 }
             });
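The most substantive addition here is the zig-zag ("boustrophedon") sweep over the (mrepeat, nrepeat) output sub-tiles: even mrepeat rows walk nrepeat forward, odd rows walk it backward, so each move of the destination slice window is a single step and the copy never jumps back to column 0. A standalone host-side sketch of just the traversal order (plain C++; the tile counts are invented for illustration):

    #include <cstdio>

    // Sketch of the zig-zag sweep the new code performs over the
    // (mrepeat, nrepeat) sub-tiles; values chosen only for illustration.
    int main()
    {
        constexpr int MRepeat = 4;
        constexpr int NRepeat = 4;

        for(int mrepeat = 0; mrepeat < MRepeat; ++mrepeat)
        {
            const bool nrepeat_forward_sweep = (mrepeat % 2 == 0);

            for(int iter = 0; iter < NRepeat; ++iter)
            {
                // even rows walk 0..NRepeat-1, odd rows walk NRepeat-1..0,
                // mirroring: nrepeat = forward ? iter : NRepeat - iter - 1
                const int nrepeat = nrepeat_forward_sweep ? iter : NRepeat - iter - 1;
                std::printf("(%d, %d) ", mrepeat, nrepeat);
            }
            std::printf("\n");
        }
    }

Because consecutive visited tiles differ by one step in exactly one dimension, MoveDstSliceWindow only ever takes nrepeat_forward_step, nrepeat_backward_step, or mrepeat_forward_step.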
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp

@@ -165,6 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
         static_for<1, nDim, 1>{}([&](auto i) {
             index_t tmp = ordered_src_access_idx[I0];

+            // TODO: BUG: should start at 1
             static_for<0, i, 1>{}([&](auto j) {
                 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
             });
@@ -412,6 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
         static_for<1, nDim, 1>{}([&](auto i) {
             index_t tmp = ordered_dst_access_idx[I0];

+            // TODO: BUG: should start at 1
             static_for<0, i, 1>{}([&](auto j) {
                 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
             });
@@ -512,7 +514,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     template <typename DstBuffer>
     __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
-        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
+        // TODO: why need remove_cvref_t ?
+        constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();

         constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
@@ -545,6 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
         forward_sweep_(I0) = true;

+        // TODO: BUG: should start at 1
         static_for<1, nDim, 1>{}([&](auto i) {
             index_t tmp = ordered_src_access_lengths[I0] - 1;
@@ -608,6 +612,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
         static_for<1, nDim, 1>{}([&](auto i) {
             index_t tmp = ordered_dst_access_lengths[I0] - 1;

+            // TODO: BUG: should start at 1
             static_for<0, i, 1>{}([&](auto j) {
                 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
             });
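The suspected bug these new TODOs flag is easy to see with concrete numbers: tmp is seeded with dimension 0's index, but the inner fold then starts at j = 0 and folds dimension 0 in a second time. A small worked example of the row-major linearization fold (plain C++; the lengths and indices are invented for illustration):

    #include <cassert>

    // Row-major linearization fold, as in the transfer code:
    //     tmp = tmp * len[j] + idx[j]
    int main()
    {
        const int len[3] = {2, 3, 4};
        const int idx[3] = {1, 2, 3};

        // correct: seed with idx[0], fold the remaining dims from j = 1
        int good = idx[0];
        for(int j = 1; j < 3; ++j) { good = good * len[j] + idx[j]; }
        assert(good == 23); // (1*3 + 2)*4 + 3

        // what the flagged loop does: seed with idx[0], then fold from
        // j = 0, which multiplies by len[0] and adds idx[0] a second time
        int bad = idx[0];
        for(int j = 0; j < 3; ++j) { bad = bad * len[j] + idx[j]; }
        assert(bad == 47); // ((1*2 + 1)*3 + 2)*4 + 3
    }

The in-code comment "should start at 1" matches this reading, though the commit only documents the issue rather than fixing it.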
composable_kernel/include/utility/common_header.hpp

@@ -35,8 +35,8 @@
 #include "dynamic_buffer.hpp"
 #include "is_known_at_compile_time.hpp"
 #include "transpose_vectors.hpp"
 #include "inner_product.hpp"
+#include "element_wise_operation.hpp"

 // TODO: remove this
 #if CK_USE_AMD_INLINE_ASM
composable_kernel/include/utility/config.hpp

@@ -24,12 +24,16 @@
 #define CK_MIN_BLOCK_PER_CU 2
 #endif

-// buffer resourse
+// GPU-specific parameters
 #if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
     defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
+// buffer resourse
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
+// wave size
+#define CK_GPU_WAVE_SIZE 64
 #elif defined(CK_AMD_GPU_GFX1030)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
+#define CK_GPU_WAVE_SIZE 32
 #endif

 // FMA instruction
composable_kernel/include/utility/utility.hpp

@@ -5,8 +5,12 @@
 namespace ck {

+__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
+
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

+__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
+
 __device__ index_t get_block_1d_id() { return blockIdx.x; }

 } // namespace ck
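Together with the new CK_GPU_WAVE_SIZE macros in config.hpp, these helpers let kernel code ask which wavefront a thread belongs to without hard-coding 64. A hedged sketch of how they might be used inside a kernel, assuming CK's common header is in scope (per_wave_work and its logic are invented for illustration; only get_wave_size and get_wave_local_1d_id come from this commit):

    // Hypothetical device function; assumes a 1-D thread block.
    // get_wave_size() is 64 on gfx803/900/906/908/90a and 32 on gfx1030,
    // per the new CK_GPU_WAVE_SIZE definitions in config.hpp.
    __device__ void per_wave_work()
    {
        const ck::index_t wave_id = ck::get_wave_local_1d_id(); // which wave in the block
        const ck::index_t lane_id = ck::get_thread_local_1d_id() % ck::get_wave_size();

        // e.g. lane 0 of each wave could stage that wave's data into LDS
        if(lane_id == 0)
        {
            // ... one-per-wave setup, indexed by wave_id ...
            (void)wave_id;
        }
    }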
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp

@@ -29,8 +29,8 @@ template <typename InDataType,
           ck::index_t NPerBlock,
           ck::index_t K0PerBlock,
           ck::index_t K1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
+          ck::index_t MPerXdl,
+          ck::index_t NPerXdl,
           ck::index_t MXdlPerWave,
           ck::index_t NXdlPerWave,
           typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -266,8 +266,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
         MPerBlock,
         NPerBlock,
         K0PerBlock,
-        MPerXDL,
-        NPerXDL,
+        MPerXdl,
+        NPerXdl,
         K1,
         MXdlPerWave,
         NXdlPerWave,
@@ -299,10 +299,12 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
         ABlockLdsAddExtraM,
         BBlockLdsAddExtraN>;

+#if !DEBUG_USE_C_SHUFFLE
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
+#endif

     // Argument
     struct Argument : public BaseArgument
@@ -331,7 +333,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
               a_grid_desc_k0_m_k1_{},
               b_grid_desc_k0_n_k1_{},
               c_grid_desc_m_n_{},
+#if !DEBUG_USE_C_SHUFFLE
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
+#else
+              c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{},
+#endif
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -358,8 +364,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
             if(GridwiseGemm::CheckValidity(
                    a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
             {
+#if !DEBUG_USE_C_SHUFFLE
                 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                     GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
+#else
+                c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_ =
+                    GridwiseGemm::
+                        MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
+                            c_grid_desc_m_n_);
+#endif

                 block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
             }
@@ -372,8 +385,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-        CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        Block2CTileMap block_2_ctile_map_;
+#if !DEBUG_USE_C_SHUFFLE
+        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+#else
+        typename GridwiseGemm::CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
+            c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_;
+#endif
+        typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
         InElementwiseOperation in_element_op_;
@@ -427,11 +447,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                     CDataType,
                     remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#if !DEBUG_USE_C_SHUFFLE
+                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#else
+                    remove_reference_t<typename GridwiseGemm::
+                        CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
+#endif
                     InElementwiseOperation,
                     WeiElementwiseOperation,
                     OutElementwiseOperation,
-                    remove_reference_t<DeviceOp::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     true>;

                 ave_time = launch_and_time_kernel(kernel,
@@ -444,7 +470,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                     arg.p_c_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
+#if !DEBUG_USE_C_SHUFFLE
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+#else
+                    arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
+#endif
                     arg.in_element_op_,
                     arg.wei_element_op_,
                     arg.out_element_op_,
@@ -458,11 +488,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                     CDataType,
                     remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#if !DEBUG_USE_C_SHUFFLE
+                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+#else
+                    remove_reference_t<typename GridwiseGemm::
+                        CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
+#endif
                     InElementwiseOperation,
                     WeiElementwiseOperation,
                     OutElementwiseOperation,
-                    remove_reference_t<DeviceOp::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     false>;

                 ave_time = launch_and_time_kernel(kernel,
@@ -475,7 +511,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                     arg.p_c_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
+#if !DEBUG_USE_C_SHUFFLE
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+#else
+                    arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
+#endif
                     arg.in_element_op_,
                     arg.wei_element_op_,
                     arg.out_element_op_,
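Every change in this file is the same mechanical move: each site that names the C grid descriptor is wrapped so the DEBUG_USE_C_SHUFFLE flag (defined in gridwise_gemm_xdlops_v3r1.hpp) flips the whole device op between the two output layouts at compile time. Reduced to a skeleton (the CDesc alias is invented for illustration; the real code spells out the full types at each site):

    // Skeleton of the compile-time switch used throughout this file.
    #if !DEBUG_USE_C_SHUFFLE
    using CDesc = typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2;
    #else
    using CDesc = typename GridwiseGemm::
        CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl;
    #endif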
device_operation/include/element_wise_operation.hpp
deleted 100644 → 0 — former contents (at parent 2fd5e6ae):

#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct PassThrough
{
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v;
    }
};

struct AddRelu
{
    template <typename T1>
    __host__ constexpr float operator()(float v0, T1 v1) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        return c;
    }

    template <typename T1>
    __device__ constexpr float operator()(float v0, T1 v1) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        return b;
#else
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        return c;
#endif
    }
};

struct AddReluAdd
{
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float b = v0 + v1;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        float a = v1 + v0;
        float b = max(a, float(0));
        float c = b + v2;
        return c;
#else
        float b = v1 + v2;
        float c = (v0 > -v1) ? b + v0 : v2;
        return c;
#endif
    }
};

struct AddLeakyReluAdd
{
    template <typename T1, typename T2>
    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
    }

    template <typename T1, typename T2>
    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
    {
#if 0
        // this use not too many registers, but use fp64 mul
        float a = v0 + v1;
        float b = 0.1 * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this spill register
        float a = v0 + v1;
        float b = float(0.1) * a;
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#elif 0
        // this use lots of registers (but no spill)
        constexpr float alpha     = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;
        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = b > 0 ? b : 0;
        float d = alpha * (a + c);
        return d;
#elif 1
        // this use lots of registers (but no spill), 89 Tflops
        constexpr float alpha     = 0.1;
        constexpr float alpha_inv = 1.0 / alpha;
        float a = v2 * alpha_inv;
        float b = v1 + v0;
        float c = max(b, float(0));
        float d = alpha * (a + c);
        return d;
#elif 1
        // this spill registers, 89 Tflops
        float a     = v0 + v1;
        float alpha = 0.1;
        float b;
        asm volatile("\n \
                     v_mul_f32_e32 %0, %1, %2 \n \
                     "
                     : "=v"(b)
                     : "s"(alpha), "v"(a));
        float c = b > 0 ? b : 0;
        float d = c + v2;
        return d;
#endif
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
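Before its deletion, the interesting part of this file was AddLeakyReluAdd's device path: the enabled variant factors alpha out so the trailing add is absorbed into the scaled sum, d = alpha * (v2 / alpha + max(v0 + v1, 0)), which is algebraically equal to the host formula max(alpha * (v0 + v1), 0) + v2 because a positive alpha commutes with max. A small host-side check of that identity (values invented for illustration):

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    // Check that the enabled device variant of AddLeakyReluAdd matches the
    // straightforward host formula; alpha = 0.1 as in the deleted file.
    int main()
    {
        const float v0 = 2.5f, v1 = -1.0f, v2 = 0.75f;

        // host path: d = max(alpha * (v0 + v1), 0) + v2
        const float host = std::max(0.1f * (v0 + v1), 0.0f) + v2;

        // device path: factor alpha out so the final add rides along in
        // the scaled term: d = alpha * (v2 / alpha + max(v0 + v1, 0))
        constexpr float alpha     = 0.1f;
        constexpr float alpha_inv = 1.0f / alpha;
        const float device = alpha * (v2 * alpha_inv + std::max(v0 + v1, 0.0f));

        // identical up to rounding, since alpha > 0 commutes with max()
        assert(std::fabs(host - device) < 1e-6f);
    }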