Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
70eebf22
"vscode:/vscode.git/clone" did not exist on "52be1ccfb9991b655c0d0840a212cfbfc0c77696"
Unverified
Commit
70eebf22
authored
Nov 07, 2023
by
zjing14
Committed by
GitHub
Nov 07, 2023
Browse files
Merge branch 'develop' into grouped_gemm_multi_abd_fixed_nk_example
parents
5608328c
98fd41f5
Changes
217
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
858 additions
and
239 deletions
+858
-239
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp
...sor_operation/gpu/device/impl/device_elementwise_impl.hpp
+22
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+23
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+10
-10
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
...operation/gpu/device/impl/device_image_to_column_impl.hpp
+58
-33
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+19
-0
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+65
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+180
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+91
-46
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
...k/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+32
-11
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+6
-1
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+98
-49
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+2
-0
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+0
-22
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+184
-8
include/ck/utility/statically_indexed_array_multi_index.hpp
include/ck/utility/statically_indexed_array_multi_index.hpp
+1
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+38
-18
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...erence_tensor_operation/cpu/reference_column_to_image.hpp
+21
-20
No files found.
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
70eebf22
...
@@ -145,7 +145,8 @@ template <index_t NumDimM,
...
@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
typename
ComputeDataType
=
ADataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceContractionMultipleD_Xdl_CShuffle
struct
DeviceContractionMultipleD_Xdl_CShuffle
:
public
DeviceContractionMultipleD
<
NumDimM
,
:
public
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType
,
EDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
,
ComputeDataType
>
{
{
using
DeviceOp
=
DeviceContractionMultipleD_Xdl_CShuffle
;
using
DeviceOp
=
DeviceContractionMultipleD_Xdl_CShuffle
;
...
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
ComputeDataType
=
ADataType
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
...
...
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp
View file @
70eebf22
...
@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
...
@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
{
{
return
std
::
make_unique
<
Invoker
>
();
return
std
::
make_unique
<
Invoker
>
();
};
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceElementwiseImpl<"
;
str
<<
"NumDim_"
<<
NumDim
<<
","
;
str
<<
"MPerThread_"
<<
MPerThread
<<
","
;
str
<<
"InScalarPerVector"
;
static_for
<
0
,
InScalarPerVectorSeq
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
str
<<
"_"
<<
InScalarPerVectorSeq
::
At
(
i
).
value
;
});
str
<<
","
;
str
<<
"OutScalarPerVector"
;
static_for
<
0
,
OutScalarPerVectorSeq
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
str
<<
"_"
<<
OutScalarPerVectorSeq
::
At
(
i
).
value
;
});
str
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
};
// namespace device
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
70eebf22
...
@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return
false
;
return
false
;
}
}
}
}
else
if
(
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
)
else
if
(
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
70eebf22
...
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle"
str
<<
"DeviceGemm_Xdl_CShuffle"
<<
"<"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
...
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<<
" LoopScheduler: "
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
;
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
// clang-format on
return
str
.
str
();
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
70eebf22
...
@@ -59,7 +59,8 @@ template <typename ADataType,
...
@@ -59,7 +59,8 @@ template <typename ADataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
ComputeType
=
CDataType
,
typename
ComputeType
=
CDataType
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
BLayout
,
BLayout
,
...
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
// TODO: should be exposed as Tparams.
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
LoopScheduler
LoopSched
=
make_default_loop_scheduler
();
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
BlockSize
,
...
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
MPadded_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
K0
Padded
_
,
index_t
k_batch_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
BElementwiseOperation
b_element_op_
,
...
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_
,
MPadded_
,
NPadded_
,
NPadded_
,
KPadded_
,
KPadded_
,
K0_
,
K0
Padded
_
,
k_batch_
),
k_batch_
),
a_element_op
(
a_element_op_
),
a_element_op
(
a_element_op_
),
b_element_op
(
b_element_op_
),
b_element_op
(
b_element_op_
),
...
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
auto
K0
=
karg
.
K0
;
const
auto
K0
Padded
=
karg
.
K0
Padded
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
Padded
);
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
KBatch
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
KBatch
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
}
// polymorphic
// polymorphic
std
::
string
GetTypeString
()
const
override
{
return
GridwiseGemm
::
GetTypeString
();
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
str
<<
GridwiseGemm
::
GetTypeString
()
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
return
str
.
str
();
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
70eebf22
...
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
K
,
K_BATCH
);
const
index_t
k0
_padded
=
GridwiseGemm
::
CalculateK0
Padded
(
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_c
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_c
);
...
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
m_padded
,
m_padded
,
n_padded
,
n_padded
,
k_padded
,
k_padded
,
k0
,
k0
_padded
,
K_BATCH
};
K_BATCH
};
gemm_kernel_args_
.
emplace_back
(
gemm_kernel_args_
.
emplace_back
(
...
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
karg
.
K
,
K_BATCH
);
const
index_t
k0
_padded
=
GridwiseGemm
::
CalculateK0
Padded
(
karg
.
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
...
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
karg
.
KPadded
=
k_padded
;
karg
.
KPadded
=
k_padded
;
karg
.
K0
=
k0
;
karg
.
K0
Padded
=
k0
_padded
;
karg
.
k_batch
=
K_BATCH
;
karg
.
k_batch
=
K_BATCH
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
index_t
K0
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
K0
;
index_t
K0
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
K0
Padded
;
bool
all_have_kbatch_gt_one
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
k_batch
>
1
;
bool
all_have_kbatch_gt_one
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
k_batch
>
1
;
bool
all_have_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
bool
all_have_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
...
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw
std
::
runtime_error
(
err
.
str
());
throw
std
::
runtime_error
(
err
.
str
());
}
}
K0
=
karg
.
K0
;
K0
=
karg
.
K0
Padded
;
bool
not_all_have_main_k0_block_loop_same
=
bool
not_all_have_main_k0_block_loop_same
=
all_have_main_k0_block_loop
xor
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
all_have_main_k0_block_loop
xor
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
bool
not_all_have_kbatch_value_same
=
all_have_kbatch_gt_one
xor
(
kbatch
>
1
);
bool
not_all_have_kbatch_value_same
=
all_have_kbatch_gt_one
xor
(
kbatch
>
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
View file @
70eebf22
...
@@ -15,15 +15,18 @@
...
@@ -15,15 +15,18 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// Image to column for input layout NDHWC:
// Image to column:
// input : input image [N, Di, Hi, Wi, C]
// input : input image [G, N, Di, Hi, Wi, C]
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
// input : input image [N, Di, Hi, Wi, G, C]
// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
template
<
index_t
NDimSpatial
,
template
<
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
ImageLayout
,
typename
InputDataType
,
typename
InputDataType
,
...
@@ -41,6 +44,14 @@ struct DeviceImageToColumnImpl
...
@@ -41,6 +44,14 @@ struct DeviceImageToColumnImpl
OutputDataType
,
OutputDataType
,
conv_tensor_rearrange_op
::
ImageToColumn
>
conv_tensor_rearrange_op
::
ImageToColumn
>
{
{
static
constexpr
bool
is_NSpatialGC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
;
static
constexpr
bool
is_GNSpatialC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -109,7 +120,7 @@ struct DeviceImageToColumnImpl
...
@@ -109,7 +120,7 @@ struct DeviceImageToColumnImpl
const
ck
::
index_t
C
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
)
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
)
{
{
const
index_t
NDoHoWo
=
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
N
*
ck
::
accumulate_n
<
index_t
>
(
...
@@ -117,11 +128,10 @@ struct DeviceImageToColumnImpl
...
@@ -117,11 +128,10 @@ struct DeviceImageToColumnImpl
const
index_t
CZYX
=
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
C
*
ck
::
accumulate_n
<
index_t
>
(
filter_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
filter_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
auto
desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NDoHoWo
,
CZYX
),
make_tuple
(
gemm_m_k_strides
[
I0
],
gemm_m_k_strides
[
I1
]));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_mraw_kraw
);
const
auto
desc_mraw_kraw
=
make_naive_tensor_descriptor
(
return
desc_m_k
;
make_tuple
(
NDoHoWo
,
CZYX
),
make_tuple
(
gemm_g_m_k_strides
[
I1
],
gemm_g_m_k_strides
[
I2
]));
return
matrix_padder
.
PadADescriptor_M_K
(
desc_mraw_kraw
);
}
}
using
InputGridDesc
=
using
InputGridDesc
=
...
@@ -132,34 +142,38 @@ struct DeviceImageToColumnImpl
...
@@ -132,34 +142,38 @@ struct DeviceImageToColumnImpl
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
OutputGridDesc
{}))
>
;
OutputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
using
GridwiseTensorRearrangeKernel
=
InputDataType
,
GridwiseTensorRearrange
<
InputGridDesc
,
OutputGridDesc
,
InputDataType
,
OutputDataType
,
OutputGridDesc
,
BlockSize
,
OutputDataType
,
MPerBlock
,
BlockSize
,
KPerBlock
,
MPerBlock
,
ThreadClusterLengths
,
KPerBlock
,
ScalarPerVector
,
ThreadClusterLengths
,
InMemoryDataOperationEnum
::
Set
,
ScalarPerVector
,
Block2ETileMap
>
;
InMemoryDataOperationEnum
::
Set
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
void
*
p_in
,
// input image
Argument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// gemm form
void
*
p_out
,
// gemm form
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
:
C_
(
C
),
:
G_
(
G
),
C_
(
C
),
X_
(
filter_spatial_lengths
[
NDimSpatial
-
I1
]),
X_
(
filter_spatial_lengths
[
NDimSpatial
-
I1
]),
p_in_
{
static_cast
<
const
InputDataType
*>
(
p_in
)},
p_in_
{
static_cast
<
const
InputDataType
*>
(
p_in
)},
p_out_
{
static_cast
<
OutputDataType
*>
(
p_out
)},
p_out_
{
static_cast
<
OutputDataType
*>
(
p_out
)},
...
@@ -176,14 +190,16 @@ struct DeviceImageToColumnImpl
...
@@ -176,14 +190,16 @@ struct DeviceImageToColumnImpl
filter_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
image_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
out_grid_desc_m_k_
=
MakeOutDescriptor_M_K
(
out_grid_desc_m_k_
=
MakeOutDescriptor_M_K
(
N
,
C
,
filter_spatial_lengths
,
output_spatial_lengths
,
gemm_m_k_strides
);
N
,
C
,
filter_spatial_lengths
,
output_spatial_lengths
,
gemm_g_m_k_strides
);
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
image_g_n_c_wis_strides
[
I0
];
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
gemm_g_m_k_strides
[
I0
];
}
}
void
Print
()
const
void
Print
()
const
...
@@ -192,6 +208,7 @@ struct DeviceImageToColumnImpl
...
@@ -192,6 +208,7 @@ struct DeviceImageToColumnImpl
std
::
cout
<<
out_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
out_grid_desc_m_k_
<<
std
::
endl
;
}
}
const
ck
::
index_t
G_
;
const
ck
::
index_t
C_
;
const
ck
::
index_t
C_
;
const
ck
::
index_t
X_
;
const
ck
::
index_t
X_
;
...
@@ -206,6 +223,8 @@ struct DeviceImageToColumnImpl
...
@@ -206,6 +223,8 @@ struct DeviceImageToColumnImpl
InputGridDesc
in_grid_desc_m_k_
;
InputGridDesc
in_grid_desc_m_k_
;
OutputGridDesc
out_grid_desc_m_k_
;
OutputGridDesc
out_grid_desc_m_k_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
@@ -220,12 +239,14 @@ struct DeviceImageToColumnImpl
...
@@ -220,12 +239,14 @@ struct DeviceImageToColumnImpl
const
auto
block_2_tile_map
=
const
auto
block_2_tile_map
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
arg
.
out_grid_desc_m_k_
);
arg
.
out_grid_desc_m_k_
);
const
index_t
grid_size
=
block_2_tile_map
.
CalculateGridSize
(
arg
.
out_grid_desc_m_k_
);
const
index_t
grid_size
=
const
auto
kernel
=
kernel_tensor_rearrange
<
InputGridDesc
,
block_2_tile_map
.
CalculateGridSize
(
arg
.
out_grid_desc_m_k_
)
*
arg
.
G_
;
const
auto
kernel
=
kernel_tensor_rearrange
<
InputGridDesc
,
InputDataType
,
InputDataType
,
OutputGridDesc
,
OutputGridDesc
,
OutputDataType
,
OutputDataType
,
Block2ETileMap
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
GridwiseTensorRearrangeKernel
>
;
GridwiseTensorRearrangeKernel
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
...
@@ -237,7 +258,9 @@ struct DeviceImageToColumnImpl
...
@@ -237,7 +258,9 @@ struct DeviceImageToColumnImpl
arg
.
p_in_
,
arg
.
p_in_
,
arg
.
out_grid_desc_m_k_
,
arg
.
out_grid_desc_m_k_
,
arg
.
p_out_
,
arg
.
p_out_
,
block_2_tile_map
);
arg
.
G_
,
block_2_tile_map
,
arg
.
compute_ptr_offset_of_batch_
);
return
elapsed_time
;
return
elapsed_time
;
}
}
...
@@ -250,9 +273,7 @@ struct DeviceImageToColumnImpl
...
@@ -250,9 +273,7 @@ struct DeviceImageToColumnImpl
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
is_NSpatialGC
||
is_GNSpatialC
))
if
constexpr
(
!
(
std
::
is_same_v
<
ImageLayout
,
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNDHWC
>
))
{
{
return
false
;
return
false
;
}
}
...
@@ -295,13 +316,14 @@ struct DeviceImageToColumnImpl
...
@@ -295,13 +316,14 @@ struct DeviceImageToColumnImpl
static
auto
MakeArgument
(
const
void
*
p_in
,
// input image
static
auto
MakeArgument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// gemm form
void
*
p_out
,
// gemm form
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -309,13 +331,14 @@ struct DeviceImageToColumnImpl
...
@@ -309,13 +331,14 @@ struct DeviceImageToColumnImpl
{
{
return
Argument
{
static_cast
<
const
InputDataType
*>
(
p_in
),
return
Argument
{
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
N
,
C
,
C
,
input_spatial_lengths
,
input_spatial_lengths
,
filter_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
gemm_
g_
m_k_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -327,13 +350,14 @@ struct DeviceImageToColumnImpl
...
@@ -327,13 +350,14 @@ struct DeviceImageToColumnImpl
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
// input image
MakeArgumentPointer
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// gemm form
void
*
p_out
,
// gemm form
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
2
>&
gemm_m_k_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_
g_
m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -341,13 +365,14 @@ struct DeviceImageToColumnImpl
...
@@ -341,13 +365,14 @@ struct DeviceImageToColumnImpl
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_in
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
N
,
C
,
C
,
input_spatial_lengths
,
input_spatial_lengths
,
filter_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
image_g_n_c_wis_strides
,
gemm_m_k_strides
,
gemm_
g_
m_k_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
70eebf22
...
@@ -186,6 +186,25 @@ struct Bilinear
...
@@ -186,6 +186,25 @@ struct Bilinear
y
=
type_convert
<
half_t
>
(
alpha_
*
x0
+
beta_
*
ck
::
type_convert
<
float
>
(
x1
));
y
=
type_convert
<
half_t
>
(
alpha_
*
x0
+
beta_
*
ck
::
type_convert
<
float
>
(
x1
));
};
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x0_tmp
=
type_convert
<
float
>
(
x0
);
const
float
x1_tmp
=
type_convert
<
float
>
(
x1
);
const
float
y_tmp
=
alpha_
*
x0_tmp
+
beta_
*
x1_tmp
;
y
=
type_convert
<
bhalf_t
>
(
y_tmp
);
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
>
(
bhalf_t
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x1_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
const
float
y_tmp
=
alpha_
*
x0
+
beta_
*
x1_tmp
;
y
=
y_tmp
;
};
template
<
>
template
<
>
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
70eebf22
...
@@ -311,6 +311,71 @@ struct AddAddFastGelu
...
@@ -311,6 +311,71 @@ struct AddAddFastGelu
}
}
};
};
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct
ScaleAddScaleAddRelu
{
ScaleAddScaleAddRelu
(
const
float
alpha1
=
1.
f
,
const
float
alpha2
=
1.
f
)
:
alpha1_
(
alpha1
),
alpha2_
(
alpha2
)
{
}
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x
=
c
*
alpha1_
+
alpha2_
*
d0
+
d1
;
Relu
{}.
template
operator
()
<
float
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
float
x
=
type_convert
<
float
>
(
c
)
*
alpha1_
+
alpha2_
*
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
);
e
=
type_convert
<
half_t
>
(
result
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
e
,
const
bhalf_t
&
c
,
const
bhalf_t
&
d0
,
const
bhalf_t
&
d1
)
const
{
const
float
x
=
type_convert
<
float
>
(
c
)
*
alpha1_
+
alpha2_
*
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
);
e
=
type_convert
<
bhalf_t
>
(
result
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int8_t
,
float
,
float
>
(
int8_t
&
e
,
const
int8_t
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x
=
type_convert
<
float
>
(
c
)
*
alpha1_
+
alpha2_
*
d0
+
d1
;
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
);
e
=
type_convert
<
int8_t
>
(
result
);
}
const
float
alpha1_
;
const
float
alpha2_
;
};
struct
Normalize
struct
Normalize
{
{
// FIXME: is double absolutely necessary?
// FIXME: is double absolutely necessary?
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
70eebf22
...
@@ -16,6 +16,57 @@ namespace element_wise {
...
@@ -16,6 +16,57 @@ namespace element_wise {
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
#endif
struct
PassThroughPack2
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
// fake conversion
uint16_t
t
=
ck
::
bit_cast
<
uint32_t
>
(
x
);
y
=
ck
::
bit_cast
<
ck
::
f8x2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
auto
t
=
type_convert
<
float2_t
>
(
x
);
y
=
type_convert
<
half2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
float2_t
&
y
,
const
ck
::
float2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
int8x2_t
&
y
,
const
ck
::
int8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
bhalf2_t
&
y
,
const
ck
::
bhalf2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
double2_t
&
y
,
const
ck
::
double2_t
&
x
)
const
{
y
=
x
;
}
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
struct
PassThrough
struct
PassThrough
{
{
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
@@ -33,6 +84,12 @@ struct PassThrough
...
@@ -33,6 +84,12 @@ struct PassThrough
y
=
type_convert
<
float
>
(
x
);
y
=
type_convert
<
float
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
double
,
float
>
(
double
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
double
>
(
x
);
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
...
@@ -69,6 +126,12 @@ struct PassThrough
...
@@ -69,6 +126,12 @@ struct PassThrough
y
=
type_convert
<
bhalf_t
>
(
x
);
y
=
type_convert
<
bhalf_t
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
bhalf_t
>
(
float
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
{
{
...
@@ -207,7 +270,8 @@ struct ConvertF8SR
...
@@ -207,7 +270,8 @@ struct ConvertF8SR
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
{
// check Y datatype
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
f8_t
>::
value
||
is_same
<
Y
,
bf8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
...
@@ -224,6 +288,20 @@ struct Scale
...
@@ -224,6 +288,20 @@ struct Scale
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
ck
::
type_convert
<
half_t
>
(
scale_
)
*
x
;
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
const
float
x_tmp
=
ck
::
type_convert
<
float
>
(
x
);
const
float
y_tmp
=
scale_
*
x_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
};
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
...
@@ -442,10 +520,11 @@ struct Sigmoid
...
@@ -442,10 +520,11 @@ struct Sigmoid
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
1
/
(
ck
::
type_convert
<
T
>
(
1
)
+
exp
(
-
x
));
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
};
};
};
};
...
@@ -455,7 +534,8 @@ struct TanH
...
@@ -455,7 +534,8 @@ struct TanH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
y
=
ck
::
math
::
tanh
(
x
);
...
@@ -481,7 +561,101 @@ struct Swish
...
@@ -481,7 +561,101 @@ struct Swish
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
};
};
float
beta_
=
1.0
f
;
const
float
beta_
;
};
struct
SoftRelu
{
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
struct
Power
{
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
const
float
alpha_
;
const
float
beta_
;
const
float
gamma_
;
};
struct
ClippedRelu
{
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
const
float
alpha_
;
const
float
beta_
;
};
struct
LeakyRelu
{
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
const
float
alpha_
;
};
struct
Elu
{
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
const
float
alpha_
;
};
};
}
// namespace element_wise
}
// namespace element_wise
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
70eebf22
...
@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
MPadded
;
index_t
MPadded
;
index_t
NPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
KPadded
;
index_t
K0
;
index_t
K0
Padded
;
index_t
k_batch
;
index_t
k_batch
;
Argument
(
const
FloatA
*
p_a_grid_
,
Argument
(
const
FloatA
*
p_a_grid_
,
...
@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
MPadded_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
K0
Padded
_
,
index_t
k_batch_
)
index_t
k_batch_
)
:
p_a_grid
(
p_a_grid_
),
:
p_a_grid
(
p_a_grid_
),
p_b_grid
(
p_b_grid_
),
p_b_grid
(
p_b_grid_
),
...
@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MPadded
(
MPadded_
),
MPadded
(
MPadded_
),
NPadded
(
NPadded_
),
NPadded
(
NPadded_
),
KPadded
(
KPadded_
),
KPadded
(
KPadded_
),
K0
(
K0
_
),
K0
Padded
(
K0Padded
_
),
k_batch
(
k_batch_
)
k_batch
(
k_batch_
)
{
{
}
}
...
@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<<
"MP:"
<<
MPadded
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"K0:"
<<
K0
<<
", "
<<
"K0
Padded
:"
<<
K0
Padded
<<
", "
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
}
}
};
};
...
@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return
math
::
integer_least_multiple
(
N
,
NPerBlock
);
return
math
::
integer_least_multiple
(
N
,
NPerBlock
);
}
}
__host__
__device__
static
auto
CalculateK0
(
index_t
K
,
index_t
K_Batch
=
1
)
__host__
__device__
static
auto
CalculateK0
Padded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
{
// k_batch * k0 * k0_per_block * k1
// k_batch * k0 * k0_per_block * k1
auto
K_t
=
K_Batch
*
K0PerBlock
*
K1
;
auto
K_t
=
K_Batch
*
K0PerBlock
*
K1
;
...
@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
{
auto
K0
=
CalculateK0
(
K
,
K_Batch
);
auto
K0
Padded
=
CalculateK0
Padded
(
K
,
K_Batch
);
return
K_Batch
*
K0
*
K1
;
return
K_Batch
*
K0
Padded
*
K1
;
}
}
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
...
@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
KBatch
,
index_t
KBatch
,
index_t
K0
,
index_t
K0
Padded
,
index_t
KPad
)
index_t
KPad
)
{
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
...
@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}
}();
}();
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
else
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
pad
,
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
Padded
,
K1
)),
make_pass_through_transform
(
M
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
N
,
index_t
N
,
index_t
StrideB
,
index_t
StrideB
,
index_t
KBatch
,
index_t
KBatch
,
index_t
K0
,
index_t
K0
Padded
,
index_t
KPad
)
index_t
KPad
)
{
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
...
@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}
}();
}();
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
else
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_k
pad
_n
,
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
Padded
,
K1
)),
make_pass_through_transform
(
N
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return
false
;
return
false
;
}
}
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
...
@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
auto
K_t
=
karg
.
k_batch
*
K0PerBlock
*
K1
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
#if DEBUG_LOG
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
#endif // DEBUG_LOG
return
false
;
return
false
;
}
}
...
@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
{
#if DEBUG_LOG
#if DEBUG_LOG
std
::
cout
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
<<
") value is not a multiple of
CBlockTransferScalarPerVector_NWaveNPerXDL ("
"
CBlockTransferScalarPerVector_NWaveNPerXDL ("
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
#endif // DEBUG_LOG
return
false
;
return
false
;
...
@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
{
#if DEBUG_LOG
#if DEBUG_LOG
std
::
cout
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
<<
") value is not a multiple of
CBlockTransferScalarPerVector_NWaveNPerXDL ("
"
CBlockTransferScalarPerVector_NWaveNPerXDL ("
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
#endif // DEBUG_LOG
return
false
;
return
false
;
}
}
}
}
const
auto
num_k_loop
=
karg
.
K0
/
K0PerBlock
;
const
auto
num_k_loop
=
karg
.
K0
Padded
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
{
#if DEBUG_LOG
#if DEBUG_LOG
std
::
cout
<<
"The number of k loops ("
<<
num_k_loop
std
::
cout
<<
"The number of k loops ("
<<
num_k_loop
<<
") value is not supported by GridwiseGemm Pipeline."
<<
") value is not supported by GridwiseGemm Pipeline."
<<
" K0: "
<<
karg
.
K0
<<
", K0PerBlock: "
<<
K0PerBlock
<<
" "
<<
__FILE__
<<
" K0
Padded
: "
<<
karg
.
K0
Padded
<<
", K0PerBlock: "
<<
K0PerBlock
<<
" "
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
#endif // DEBUG_LOG
return
false
;
return
false
;
}
}
...
@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
{
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
K0Padded
=
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0Padded
*
K1
;
return
KPad
;
return
KPad
;
}
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
Padded
)
{
{
const
index_t
num_loop
=
K0
/
K0PerBlock
;
const
index_t
num_loop
=
K0
Padded
/
K0PerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
}
...
@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
Padded
,
karg
.
KPadded
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
Padded
,
karg
.
KPadded
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
View file @
70eebf22
...
@@ -21,6 +21,7 @@ template <typename InputGridDesc,
...
@@ -21,6 +21,7 @@ template <typename InputGridDesc,
typename
OutputGridDesc
,
typename
OutputGridDesc
,
typename
OutputDataType
,
typename
OutputDataType
,
typename
Block2ETileMap
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfStridedBatch
,
typename
GridwiseTensorRearrangeKernel
>
typename
GridwiseTensorRearrangeKernel
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
...
@@ -30,13 +31,20 @@ __global__ void
...
@@ -30,13 +31,20 @@ __global__ void
const
InputDataType
*
__restrict__
p_in_global
,
const
InputDataType
*
__restrict__
p_in_global
,
const
OutputGridDesc
out_grid_desc
,
const
OutputGridDesc
out_grid_desc
,
OutputDataType
*
__restrict__
p_out_global
,
OutputDataType
*
__restrict__
p_out_global
,
const
Block2ETileMap
block_2_tile_map
)
const
index_t
batch_count
,
const
Block2ETileMap
block_2_tile_map
,
const
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel
::
Run
(
GridwiseTensorRearrangeKernel
::
Run
(
in_grid_desc
,
in_grid_desc
,
p_in_global
,
out_grid_desc
,
p_out_global
,
block_2_tile_map
);
p_in_global
,
out_grid_desc
,
p_out_global
,
batch_count
,
block_2_tile_map
,
compute_ptr_offset_of_batch
);
#else
#else
ignore
=
in_grid_desc
;
ignore
=
in_grid_desc
;
ignore
=
p_in_global
;
ignore
=
p_in_global
;
...
@@ -56,7 +64,8 @@ template <typename InputGridDesc,
...
@@ -56,7 +64,8 @@ template <typename InputGridDesc,
typename
ThreadClusterLengths
,
typename
ThreadClusterLengths
,
index_t
ScalarPerVector
,
index_t
ScalarPerVector
,
InMemoryDataOperationEnum
DstInMemOp
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
Block2ETileMap
>
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfStridedBatch
>
struct
GridwiseTensorRearrange
struct
GridwiseTensorRearrange
{
{
...
@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
...
@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
const
InputDataType
*
__restrict__
p_in_global
,
const
InputDataType
*
__restrict__
p_in_global
,
const
OutputGridDesc
&
out_grid_desc
,
const
OutputGridDesc
&
out_grid_desc
,
OutputDataType
*
__restrict__
p_out_global
,
OutputDataType
*
__restrict__
p_out_global
,
const
Block2ETileMap
&
block_2_tile_map
)
const
index_t
batch_count
,
const
Block2ETileMap
&
block_2_tile_map
,
const
ComputePtrOffsetOfStridedBatch
&
compute_ptr_offset_of_batch
)
{
{
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
...
@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
...
@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
const
index_t
k_block_data_idx_on_grid
=
const
index_t
k_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
KPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
KPerBlock
);
// Global Memory
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc
.
GetElementSpaceSize
());
auto
copy_global_to_global
=
auto
copy_global_to_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
Tuple
<
InputDataType
>
,
Tuple
<
InputDataType
>
,
...
@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
...
@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
make_tuple
(
make_multi_index
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
)),
make_tuple
(
make_multi_index
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
)),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
// Global Memory
const
index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
));
const
index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
));
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
+
a_batch_offset
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
+
c_batch_offset
,
out_grid_desc
.
GetElementSpaceSize
());
copy_global_to_global
.
Run
(
copy_global_to_global
.
Run
(
tie
(
in_grid_desc
),
tie
(
in_global_buf
),
tie
(
out_grid_desc
),
tie
(
out_global_buf
));
tie
(
in_grid_desc
),
tie
(
in_global_buf
),
tie
(
out_grid_desc
),
tie
(
out_global_buf
));
}
}
...
...
include/ck/utility/data_type.hpp
View file @
70eebf22
...
@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
...
@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
{
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
bias
=
127
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
...
@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
...
@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
bias
=
15
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
...
@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
...
@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
{
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
};
template
<
>
template
<
>
...
@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
...
@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
};
//
}
// namespace ck
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
70eebf22
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
namespace
ck
{
namespace
ck
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -17,6 +16,9 @@ enum class f8_rounding_mode
...
@@ -17,6 +16,9 @@ enum class f8_rounding_mode
stochastic
stochastic
};
};
__host__
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
__device__
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
}
// namespace ck
}
// namespace ck
namespace
ck
::
utils
{
namespace
ck
::
utils
{
...
@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
// nan code is same for float and half
constexpr
Y
nan_code
=
0x80
;
constexpr
Y
nan_code
=
0x80
;
...
@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
NumericUtils
<
X
>::
bias
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
in_exp
-
1
))
-
(
1
<<
(
out_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
...
@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
// check if x is 0.0
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
return
0
;
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
// First need to check if it is normal or denorm as there is a difference of implict 1
if
(
exponent
<=
0
)
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
+
1
-
exponent
))
-
1
;
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
mantissa
+=
1
<<
in_mant
;
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// apply random number if needed
// exponent and mantissa again3
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
in_mant
))
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
{
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
mantissa
>>=
1
;
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
exponent
++
;
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in_mant
);
// Add the implicit 1 into mantissa
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
// check negative exponent
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
if
(
exponent
<=
0
)
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out_exponent
==
0
)
{
{
if
(
x_bitwise
==
0
)
if
((
1
<<
in_mant
)
&
mantissa
)
return
0
;
else
{
{
// subnormal range; represented by a subnormal float8 (exponent 0)
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// and involves loss of accuracy
// No need to make 1 implicit now as it will be addressed later
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
}
}
}
// above range: quantize to maximum possible float of the same sign
else
else
if
(
exponent
>
max_exp
)
{
if
((
1
<<
(
in_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
if
(
out_exponent
>
max_exp
)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
out_mant
)
-
1
;
mantissa
=
(
1
<<
out_mant
)
-
1
;
exponent
=
max_exp
;
out_
exponent
=
max_exp
;
}
}
else
else
{
{
...
@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
}
// check if x is 0.0 or -0.0
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
out_
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
mantissa
&=
(
1
<<
out_mant
)
-
1
;
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
out_
exponent
<<
out_mant
)
|
mantissa
;
}
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
...
@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent
++
;
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
while
(
mantissa
<
(
1
<<
in_mant
))
mantissa
<<=
sh
;
{
exponent
+=
1
-
sh
;
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
in_mant
)
-
1
);
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
...
...
include/ck/utility/inner_product.hpp
View file @
70eebf22
...
@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
...
@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
#else
c
=
__builtin_amdgcn_sdot4
(
bit_cast
<
int32_t
>
(
a
),
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
c
=
__builtin_amdgcn_sdot4
(
bit_cast
<
int32_t
>
(
a
),
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#endif
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c
=
__builtin_amdgcn_sudot4
(
true
,
bit_cast
<
int32_t
>
(
a
),
true
,
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#else
#else
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
...
...
include/ck/utility/math.hpp
View file @
70eebf22
...
@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
...
@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
}
}
// disallow implicit type casting
template
<
typename
T
>
__device__
T
exp
(
T
x
);
// TODO: add f16 support using v_exp_f16
template
<
>
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
template
<
>
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
static
inline
__host__
float
exp
(
float
x
)
{
return
std
::
expf
(
x
);
}
static
inline
__host__
double
exp
(
double
x
)
{
return
std
::
exp
(
x
);
}
// greatest common divisor, aka highest common factor
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
{
...
...
include/ck/utility/math_v2.hpp
View file @
70eebf22
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace
ck
{
namespace
ck
{
namespace
math
{
namespace
math
{
...
@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
...
@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
half_t
tanh
(
half_t
x
)
template
<
typename
T
>
inline
__host__
T
tanh
(
T
x
)
{
{
return
static_cast
<
half_t
>
(
std
::
tanh
(
static_cas
t
<
float
>
(
x
)));
return
ck
::
type_convert
<
T
>
(
std
::
tanhf
(
ck
::
type_conver
t
<
float
>
(
x
)));
};
};
static
inline
__host__
float
tanh
(
float
x
)
{
return
std
::
tanh
(
x
);
};
template
<
>
inline
__host__
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
inline
__host__
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
inline
__host__
T
exp
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
expf
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
exp
<
float
>
(
float
x
)
{
return
std
::
expf
(
x
);
}
static
inline
__host__
double
tanh
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
>
inline
__host__
double
exp
<
double
>
(
double
x
)
{
return
std
::
exp
(
x
);
}
template
<
typename
T
>
inline
__host__
T
log
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
logf
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
log
<
float
>
(
float
x
)
{
return
std
::
logf
(
x
);
}
template
<
>
inline
__host__
double
log
<
double
>
(
double
x
)
{
return
std
::
log
(
x
);
}
template
<
typename
T
>
inline
__host__
T
pow
(
T
x
,
T
gamma
)
{
return
ck
::
type_convert
<
T
>
(
std
::
powf
(
ck
::
type_convert
<
float
>
(
x
),
ck
::
type_convert
<
float
>
(
gamma
)));
}
template
<
>
inline
__host__
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
std
::
powf
(
x
,
gamma
);
}
template
<
>
inline
__host__
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
std
::
pow
(
x
,
gamma
);
}
template
<
typename
T
>
inline
__host__
T
expm1
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
expm1f
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
expm1
<
float
>
(
float
x
)
{
return
std
::
expm1f
(
x
);
}
template
<
>
inline
__host__
double
expm1
<
double
>
(
double
x
)
{
return
std
::
expm1
(
x
);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
...
@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
half_t
tanh
(
half_t
x
)
template
<
typename
T
>
inline
__device__
T
tanh
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
::
tanhf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
float
tanh
<
float
>
(
float
x
)
{
{
return
static_cast
<
half_t
>
(
::
tanhf
(
static_cast
<
float
>
(
x
))
);
return
::
tanhf
(
x
);
};
};
static
inline
__device__
float
tanh
(
float
x
)
{
return
::
tanhf
(
x
);
};
template
<
>
inline
__device__
double
tanh
<
double
>
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
typename
T
>
inline
__device__
T
exp
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
__expf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
half_t
exp
<
half_t
>
(
half_t
x
)
{
return
hexp
(
x
);
};
template
<
>
inline
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
};
static
inline
__device__
double
tanh
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
>
inline
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
};
template
<
typename
T
>
inline
__device__
T
log
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
__logf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
half_t
log
<
half_t
>
(
half_t
x
)
{
return
hlog
(
x
);
};
template
<
>
inline
__device__
float
log
<
float
>
(
float
x
)
{
return
__logf
(
x
);
};
template
<
>
inline
__device__
double
log
<
double
>
(
double
x
)
{
return
log
(
x
);
};
template
<
typename
T
>
inline
__device__
T
pow
(
T
x
,
T
gamma
)
{
return
ck
::
type_convert
<
T
>
(
powf
(
ck
::
type_convert
<
float
>
(
x
),
ck
::
type_convert
<
float
>
(
gamma
)));
};
template
<
>
inline
__device__
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
powf
(
x
,
gamma
);
};
template
<
>
inline
__device__
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
pow
(
x
,
gamma
);
};
template
<
typename
T
>
inline
__device__
T
expm1
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
expm1f
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
float
expm1
<
float
>
(
float
x
)
{
return
expm1f
(
x
);
};
template
<
>
inline
__device__
double
expm1
<
double
>
(
double
x
)
{
return
expm1
(
x
);
};
}
// namespace math
}
// namespace math
}
// namespace ck
}
// namespace ck
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @
70eebf22
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck/utility/type_convert.hpp
View file @
70eebf22
...
@@ -100,6 +100,8 @@ template <>
...
@@ -100,6 +100,8 @@ template <>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
union
{
{
float
fval
;
float
fval
;
...
@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
...
@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
#endif
#endif
}
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_t
>
(
f8x2_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
0
>
{}]);
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
1
>
{}]);
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
#endif
}
template
<
>
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
{
const
vector_type
<
float
,
2
>
f32x2_v
(
x
);
const
auto
y
=
__builtin_amdgcn_cvt_pkrtz
(
f32x2_v
.
template
AsType
<
float
>()[
Number
<
0
>
{}],
f32x2_v
.
template
AsType
<
float
>()[
Number
<
1
>
{}]);
return
bit_cast
<
half2_t
>
(
y
);
}
// convert fp16 to fp8
// convert fp16 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
...
@@ -145,7 +177,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -145,7 +177,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -153,8 +185,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -153,8 +185,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -165,11 +195,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -165,11 +195,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -223,7 +251,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -223,7 +251,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -231,8 +259,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -231,8 +259,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -243,11 +269,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
...
@@ -243,11 +269,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -347,7 +371,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -347,7 +371,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
@@ -356,8 +380,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -356,8 +380,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -396,7 +418,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -396,7 +418,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
@@ -406,8 +428,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -406,8 +428,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
View file @
70eebf22
...
@@ -19,9 +19,7 @@ namespace host {
...
@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for column to image.
* \brief Reference implementation for column to image.
*
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam ImageLayout Image Layout.
...
@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
if
(
!
(
arg
.
output_
.
GetNumOfDimension
()
==
NDimSpatial
+
3
&&
if
(
!
(
arg
.
output_
.
GetNumOfDimension
()
==
NDimSpatial
+
3
&&
arg
.
input_
.
GetNumOfDimension
()
==
2
))
arg
.
input_
.
GetNumOfDimension
()
==
3
))
{
{
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
}
const
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
n
)
{
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
row
=
n
*
Wo
+
wo
;
...
@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
if
(
wi
>=
0
&&
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
3
])
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
3
])
{
{
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
float
v_in
=
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
wi
));
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
row
,
column
));
arg
.
output_
(
0
,
n
,
c
,
wi
)
=
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
g
,
n
,
c
,
wi
));
arg
.
output_
(
g
,
n
,
c
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
}
column
++
;
column
++
;
...
@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
}
};
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
func
,
G
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
...
@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
n
)
{
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
...
@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg
.
output_
.
GetLengths
()[
4
])
arg
.
output_
.
GetLengths
()[
4
])
{
{
float
v_in
=
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
row
,
column
));
float
v_out
=
ck
::
type_convert
<
float
>
(
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
hi
,
wi
));
arg
.
output_
(
g
,
n
,
c
,
hi
,
wi
));
arg
.
output_
(
0
,
n
,
c
,
hi
,
wi
)
=
arg
.
output_
(
g
,
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
}
column
++
;
column
++
;
...
@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
}
};
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
func
,
G
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
...
@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
n
)
{
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
for
(
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
{
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
...
@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg
.
output_
.
GetLengths
()[
5
])
arg
.
output_
.
GetLengths
()[
5
])
{
{
float
v_in
=
ck
::
type_convert
<
float
>
(
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
arg
.
input_
(
g
,
row
,
column
));
float
v_out
=
ck
::
type_convert
<
float
>
(
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
di
,
hi
,
wi
));
arg
.
output_
(
g
,
n
,
c
,
di
,
hi
,
wi
));
arg
.
output_
(
0
,
n
,
c
,
di
,
hi
,
wi
)
=
arg
.
output_
(
g
,
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
}
column
++
;
column
++
;
...
@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
}
};
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
func
,
G
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
...
@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
...
@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C
*
ck
::
accumulate_n
<
index_t
>
(
C
*
ck
::
accumulate_n
<
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
input_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
NDoHoWo
)
&&
if
(
!
(
arg
.
input_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
G
)
&&
arg
.
input_
.
GetLengths
()[
1
]
==
static_cast
<
std
::
size_t
>
(
CZYX
)))
arg
.
input_
.
GetLengths
()[
1
]
==
static_cast
<
std
::
size_t
>
(
NDoHoWo
)
&&
arg
.
input_
.
GetLengths
()[
2
]
==
static_cast
<
std
::
size_t
>
(
CZYX
)))
{
{
return
false
;
return
false
;
}
}
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment