gaoqiong / composable_kernel_ROCM · Commits

Commit 4e911f3e, authored Jul 07, 2023 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: fa9da1a4, 4939ee59

Changes: 139 changed files in the full commit. This page shows 20 changed files with 1701 additions and 115 deletions (+1701 -115).
Files shown on this page:

example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp  +1 -1
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp  +1 -1
example/41_grouped_conv_conv_fwd/CMakeLists.txt  +2 -6
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp  +1 -1
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp  +1 -1
include/ck/ck.hpp  +23 -0
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp  +13 -11
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  +0 -2
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp  +185 -81
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp  +714 -0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp  +0 -0 (moved)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  +6 -9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp  +0 -0 (moved)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp  +0 -0 (moved)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  +0 -0 (moved)
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp  +0 -0 (moved)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +48 -0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp  +704 -0
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp

@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
 ...
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp

@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
 ...
example/41_grouped_conv_conv_fwd/CMakeLists.txt

@@ -13,10 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
+if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
     add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
-        set(target 1)
-    endif()
-endforeach()
+endif()
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp

@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
 ...
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp

@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
 ...
include/ck/ck.hpp

@@ -27,6 +27,21 @@
 #define CK_WAVELET_MIN_BLOCK_PER_CU 2
 #endif
+// kernel attribute: amdgpu_waves_per_eu()
+#ifdef CK_USE_WAVES_PER_EU
+// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
+#ifndef CK_MIN_WAVES_PER_EU
+#define CK_MIN_WAVES_PER_EU 0
+#endif
+#ifndef CK_MAX_WAVES_PER_EU
+#define CK_MAX_WAVES_PER_EU 0
+#endif
+#else
+#define CK_USE_WAVES_PER_EU 0
+#endif
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
 ...

@@ -148,6 +163,10 @@
 #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
 // experimental feature: add instances using pipeline v2
 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
+// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
+#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
+#endif
 // hack: have underlying assumption that need to be satsified, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
 ...

@@ -173,6 +192,10 @@
 // workaround: compiler issue on gfx908
 #define CK_WORKAROUND_SWDEV_388832 1
+// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
+#define CK_WORKAROUND_SWDEV_3318619 0
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
 ...
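The new CK_USE_WAVES_PER_EU / CK_MIN_WAVES_PER_EU / CK_MAX_WAVES_PER_EU macros are only declared in this hunk; how CK attaches them to its kernels is not shown here. Purely as a hedged illustration of how such switches are typically consumed, the sketch below gates the Clang/HIP amdgpu_waves_per_eu kernel attribute on them. The MY_WAVES_PER_EU_ATTR helper and the kernel itself are hypothetical and not part of this commit.

// Hypothetical sketch: attach the amdgpu_waves_per_eu occupancy hint when the feature is enabled.
// amdgpu_waves_per_eu is a standard Clang attribute for HIP/ROCm kernels; the macro values used
// here are assumptions for the sake of the example.
#include <hip/hip_runtime.h>

#define CK_USE_WAVES_PER_EU 1
#define CK_MIN_WAVES_PER_EU 1
#define CK_MAX_WAVES_PER_EU 2

#if CK_USE_WAVES_PER_EU
#define MY_WAVES_PER_EU_ATTR __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#else
#define MY_WAVES_PER_EU_ATTR
#endif

// A trivial kernel carrying the occupancy hint; real CK kernels are far more complex.
__global__ void MY_WAVES_PER_EU_ATTR scale_kernel(float* p, float s, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s;
}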
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp

@@ -4,7 +4,7 @@
 #pragma once
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
 namespace ck {
 ...

@@ -35,10 +35,11 @@ struct BlockwiseWelford
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

+    template <typename CountDataType>
     __device__ static inline void
-    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
+    Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
     {
-        int count            = count_a + count_b;
+        CountDataType count  = count_a + count_b;
         T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
         T delta              = mean_b - mean_a;
         mean_a += delta * count_b_over_count;
 ...

@@ -46,11 +47,12 @@ struct BlockwiseWelford
         count_a = count;
     }

-    __device__ static void Run(T& mean_value, T& var_value, int& count)
+    template <typename CountDataType>
+    __device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
     {
         __shared__ T mean_block_buf[BlockSize];
         __shared__ T var_block_buf[BlockSize];
-        __shared__ int count_block_buf[BlockSize];
+        __shared__ CountDataType count_block_buf[BlockSize];

         constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
 ...

@@ -76,13 +78,13 @@ struct BlockwiseWelford
                 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                      make_tuple(0, indOffset));
                 T mean1    = mean_block_buf[offset1];
                 T var1     = var_block_buf[offset1];
-                int count1 = count_block_buf[offset1];
+                CountDataType count1 = count_block_buf[offset1];
                 T mean2    = mean_block_buf[offset2];
                 T var2     = var_block_buf[offset2];
-                int count2 = count_block_buf[offset2];
+                CountDataType count2 = count_block_buf[offset2];
                 Merge(mean1, var1, count1, mean2, var2, count2);
 ...
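The Merge routine this hunk templatizes combines two partial Welford results into one. As a plain, host-only reference for the standard pairwise combination rule (Chan et al.) with the variance carried as the mean of squared deviations, here is a hedged sketch; variable names mirror the kernel, but this is not CK code, and CK's accumulator may use a different normalization internally.

#include <cstdio>

// Host-side sketch of pairwise Welford merging: given two partial (mean, var, count)
// triples, produce the combined triple. var is the biased variance (M2 / count).
template <typename T, typename CountDataType>
void welford_merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
{
    CountDataType count  = count_a + count_b;
    T count_b_over_count = count == 0 ? T(0) : T(count_b) / count;
    T delta              = mean_b - mean_a;

    mean_a += delta * count_b_over_count;
    var_a  += (var_b - var_a + delta * delta * (T(1) - count_b_over_count)) * count_b_over_count;
    count_a = count;
}

int main()
{
    // Partial results over {1,2,3} and {4,5}: means 2 and 4.5, biased variances 2/3 and 0.25.
    float mean_a = 2.0f, var_a = 2.0f / 3.0f;
    int count_a = 3;
    welford_merge(mean_a, var_a, count_a, 4.5f, 0.25f, 2);
    std::printf("mean=%f var=%f count=%d\n", mean_a, var_a, count_a); // expect mean=3, var=2, count=5
}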
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp

@@ -4,7 +4,7 @@
 #pragma once
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
 #include "ck/utility/reduction_functions_accumulate.hpp"
 namespace ck {
 ...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -786,12 +786,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
                arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
             {
-                std::cout << "first" << std::endl;
                 return false;
             }
             if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
             {
-                std::cout << "second" << std::endl;
                 return false;
             }
         }
 ...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp

@@ -10,12 +10,14 @@
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/device/welford_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
-#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
 namespace ck {
 namespace tensor_operation {
 ...
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
     {
         const auto grid_desc_m_g =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
+            make_naive_tensor_descriptor(make_tuple(invariantLength, blkGroupSize),
+                                         make_tuple(1, invariantLength));
         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 ...
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
     {
         const auto reduceLength = blkGroupSize;
         const auto grid_desc_m_k =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
+            make_naive_tensor_descriptor(make_tuple(invariantLength, reduceLength),
+                                         make_tuple(1, invariantLength));
         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 ...
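Both descriptors switch from the packed (row-major) form to an explicit stride pair (1, invariantLength): the invariant M dimension becomes the unit-stride one, presumably so neighbouring channels map to neighbouring addresses in the intermediate workspace. A hedged, standalone sketch of the resulting offset arithmetic (ordinary C++, not CK's descriptor machinery; the sizes are made up):

#include <cstdio>

// Offset of element (m, g) in an M x G buffer laid out with strides (1, M):
// for a fixed g, all M entries are contiguous; different g are strided by M.
inline int offset_m_g(int m, int g, int invariantLength) { return m * 1 + g * invariantLength; }

int main()
{
    const int invariantLength = 4; // number of channels (M), assumed for illustration
    const int blkGroupSize    = 3; // partial results per channel (G), assumed

    for(int g = 0; g < blkGroupSize; ++g)
        for(int m = 0; m < invariantLength; ++m)
            std::printf("(m=%d, g=%d) -> %d\n", m, g, offset_m_g(m, g, invariantLength));
    // The previous packed (row-major) layout would instead give offset = m * blkGroupSize + g.
}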
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);
-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            // we want the blkGroupSize be not more than 16
+            if(testBlkGroupSize <= 16)
                 break;
             iterations++;
 ...
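The loop this hunk touches searches for the smallest number of K iterations per workgroup such that the number of workgroups cooperating on one channel (blkGroupSize) stays below a cap; the cap drops from 128 to 16. A hedged standalone sketch of that search with hypothetical sizes (not CK code):

#include <cstdio>

int main()
{
    const int K_BlockTileSize = 256;    // reduce elements consumed per workgroup per iteration (assumed)
    const int reduce_length   = 65536;  // total reduce length (assumed)
    const int maxBlkGroupSize = 16;     // the new threshold from this change (was 128)

    int iterations = 1;
    while(true)
    {
        // workgroups needed along K if each one performs `iterations` block tiles
        int testBlkGroupSize =
            (reduce_length + (K_BlockTileSize * iterations) - 1) / (K_BlockTileSize * iterations);

        if(testBlkGroupSize <= maxBlkGroupSize)
        {
            std::printf("iterations=%d -> blkGroupSize=%d\n", iterations, testBlkGroupSize);
            break;
        }
        iterations++;
    }
    // With the tighter bound of 16, each workgroup does more K iterations and fewer
    // partial results have to be merged across workgroups afterwards.
}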
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         void* workspace_mean_;
         void* workspace_variance_;
         void* workspace_count_;
+
+        void* control_;
     };

     size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
 ...

@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // workspace for welford intermediate count
             workspace_size +=
                 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
+
+            // workspace for barrier objects, each barrier object consists of two integers
+            // TODO: allocate barrier object memory globally to reuse it by other operators
+            workspace_size +=
+                (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * sizeof(int) * 2;
         }

         return (workspace_size);
 ...

@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
         {
             // setup buffer used for intermediate welford mean
             pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
 ...

@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // setup buffer used for intermediate welfor count
             pArg_->workspace_count_ =
                 reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
+
+            index_t count_space_sz =
+                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
+
+            count_space_sz = math::integer_least_multiple(count_space_sz, 64);
+
+            pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
+
+            index_t control_space_sz =
+                (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * sizeof(int) * 2;
+
+            hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
         };
     };
 ...
...
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockBatchNormForward_
=
GridwiseMultiblockBatchNormForward
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
using
GridwiseMultiblockWelfordFirstHalf_
=
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
AccDataType
,
...
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
                                                BiasSrcVectorSize,
                                                MeanVarSrcDstVectorSize>;

+            // It is found that:
+            // 1) gfx1030 does not support the GLC enabled vector load/store, so using the
+            // two-kernel method for gfx1030
+            // 2) Profiler on gfx908 could hang even though it works when running examples
+            // 3) Single-kernel method works on gfx1100, but the performance it not better
+            // than two-kernel method (due to more warps participating the barrier)
+            if(ck::get_device_name() == "gfx90a")
+            {
+                const auto kern_multiblock_batchnorm_fwd_ =
+                    kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
+                                                        XDataType,
+                                                        YDataType,
+                                                        AccDataType,
+                                                        ScaleDataType,
+                                                        BiasDataType,
+                                                        MeanVarDataType,
+                                                        YElementwiseOp,
+                                                        XYGridDesc_M_K,
+                                                        MeanVarCountGridDesc_M_G,
+                                                        MeanVarCountGridDesc_M_K,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        GetReduceCountPerThreadFunctor>;
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_batchnorm_fwd_,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
+                                                  // workspace by multiple workgroups
+                    mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
+                                                  // workspace by each workgroup
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    static_cast<int*>(arg.control_),
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_, // true or false
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_, // true or false
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            }
+            else
+            {
                index_t numMeanVarCountBlockTileIteration =
                    (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;

                const auto kern_multiblock_welford_first_half =
                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
                                                         XDataType,
                                                         MeanVarDataType,
                                                         XYGridDesc_M_K,
                                                         MeanVarCountGridDesc_M_G,
                                                         GetReduceCountPerThreadFunctor>;

                const auto kern_welford_second_half_batchnorm_forward_final =
                    kernel_welford_second_half_batchnorm_forward_final<
                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
                        XDataType,
                        YDataType,
                        AccDataType,
                        ScaleDataType,
                        BiasDataType,
                        MeanVarDataType,
                        YElementwiseOp,
                        XYGridDesc_M_K,
                        MeanVarCountGridDesc_M_K,
                        ScaleBiasMeanVarGridDesc_M,
                        ScaleBiasMeanVarGridDesc_M>;

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_multiblock_welford_first_half,
                    dim3(arg.gridSize_),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k_,
                    mean_var_count_grid_desc_m_g,
                    get_reduce_count_per_thread,
                    arg.numBlockTileIteration_,
                    arg.p_x_,
                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
                    static_cast<int32_t*>(arg.workspace_count_));

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_welford_second_half_batchnorm_forward_final,
                    dim3(arg.gridSize_),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k_,
                    arg.y_grid_desc_m_k_,
                    mean_var_count_grid_desc_m_k,
                    arg.scale_grid_desc_m_,
                    arg.bias_grid_desc_m_,
                    arg.mean_var_grid_desc_m_,
                    arg.blkGroupSize_,
                    arg.numBlockTileIteration_,
                    numMeanVarCountBlockTileIteration,
                    arg.epsilon_,
                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
                    static_cast<int32_t*>(arg.workspace_count_),
                    arg.p_x_,
                    arg.p_scale_,
                    arg.p_bias_,
                    arg.y_elementwise_op_,
                    arg.p_y_,
                    arg.updateMovingAverage_,
                    arg.averageFactor_,
                    arg.resultRunningMean_,
                    arg.resultRunningVariance_,
                    arg.saveMeanInvVariance_,
                    arg.resultSaveMean_,
                    arg.resultSaveInvVariance_);
+            }
        }
        else
        {
 ...
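The single-kernel gfx90a path relies on the workgroups that share one channel tile synchronising through the barrier objects in the control_ workspace (two integers each, zero-initialised before launch). CK's actual barrier lives in the new gridwise_multiblock_batchnorm_forward.hpp, which is collapsed further down this page. Purely as a conceptual sketch, and explicitly not the CK implementation, a global-memory barrier of that flavour can be built from an atomic arrival counter plus a generation flag:

#include <hip/hip_runtime.h>

// Conceptual sketch only: a naive grid-scope barrier using one {counter, generation} pair
// in global memory. It requires all participating workgroups to be resident simultaneously,
// and it only illustrates why each barrier object needs two integers that start at zero.
__device__ void workgroup_barrier(int* barrier_obj, int num_workgroups)
{
    int* counter    = &barrier_obj[0];
    int* generation = &barrier_obj[1];

    __syncthreads();

    if(threadIdx.x == 0)
    {
        const int my_gen = __atomic_load_n(generation, __ATOMIC_ACQUIRE);

        // the last workgroup to arrive resets the counter and bumps the generation
        if(atomicAdd(counter, 1) == num_workgroups - 1)
        {
            __atomic_store_n(counter, 0, __ATOMIC_RELAXED);
            __atomic_store_n(generation, my_gen + 1, __ATOMIC_RELEASE);
        }
        else
        {
            // spin until the generation changes
            while(__atomic_load_n(generation, __ATOMIC_ACQUIRE) == my_gen) {}
        }
    }

    __syncthreads();
}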
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp (new file, mode 0 → 100644)
This diff is collapsed on this page.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -76,7 +76,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     // TODO: should be exposed as Tparams.
     static constexpr index_t NumGemmKPrefetchStage = 1;
     static constexpr LoopScheduler LoopSched       = make_default_loop_scheduler();
-    static constexpr PipelineVersion PipelineVer   = PipelineVersion::v2;
+    static constexpr PipelineVersion PipelineVer   = PipelineVersion::v1;

     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
 ...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp (file moved)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
               p_ds_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e)},
               num_group_{a_g_n_k_wos_lengths[0]},
-              num_gemm_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
 ...

@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;
-
-            // number of GEMM
-            num_gemm_ = YTilde * XTilde;
             for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
             {
                 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
 ...

@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         void Print() const
         {
-            for(index_t i = 0; i < num_gemm_; i++)
+            for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
             {
                 std::cout << "a_grid_desc_ak0_m_ak1_container_"
                           << a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
 ...

@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         // tensor descriptor for problem definition
         index_t num_group_;
-        index_t num_gemm_;
         std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
         std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
         std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
 ...

@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         float ave_time = 0;
-        for(index_t i = 0; i < arg.num_gemm_; i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
                                             arg.b_grid_desc_n_k_container_[i],
 ...

@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
         // vector load for A matrix from global memory to LDS
-        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
+        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
+                     is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
         {
             if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
             {
 ...

@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
         // vector store for E
-        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
+        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
+                     is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
         {
             // vector store C matrix into global memory
             if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
 ...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (file moved)
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp (file moved)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp (file moved)
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp (file moved)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -6,6 +6,7 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
+#include "ck/utility/type_convert.hpp"
 namespace ck {
 namespace tensor_operation {
 ...

@@ -81,6 +82,36 @@ struct PassThrough
         y = x;
     }
 #endif

+    template <>
+    __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
+    {
+        y = type_convert<float>(x);
+    }
+
+    template <>
+    __host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
+    {
+        y = type_convert<f8_t>(x);
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
+    {
+        y = type_convert<half_t>(x);
+    }
+
+    template <>
+    __host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
+    {
+        y = type_convert<f8_t>(x);
+    }
 };

 struct UnaryConvert
 ...

@@ -109,6 +140,23 @@ struct ConvertBF16RTN
     }
 };

+struct ConvertF8SR
+{
+    // convert to fp8 using stochastic rounding (SR)
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
+    {
+        // check Y datatype
+        static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
+
+        // check X datatype
+        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = f8_convert_sr<Y>(x);
+    }
+};
+
 struct Scale
 {
     __host__ __device__ Scale(float scale) : scale_(scale) {}
 ...
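ConvertF8SR delegates the actual conversion to f8_convert_sr, whose definition is not part of this page. Purely as a conceptual, host-side illustration of stochastic rounding (not CK's fp8 converter, and using an arbitrary grid spacing rather than the fp8 format), rounding a value up or down with probability proportional to its distance from the two neighbouring grid points looks like this:

#include <cmath>
#include <cstdio>
#include <random>

// Conceptual sketch: stochastically round x to a grid with spacing `step`.
// Rounding up happens with probability equal to the fractional distance, so the
// expected value of the rounded result equals x (unbiased), unlike round-to-nearest.
float stochastic_round(float x, float step, std::mt19937& rng)
{
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    const float scaled  = x / step;
    const float lower   = std::floor(scaled);
    const float frac    = scaled - lower; // in [0, 1)
    const float rounded = (uniform(rng) < frac) ? lower + 1.0f : lower;
    return rounded * step;
}

int main()
{
    std::mt19937 rng(42);
    double sum = 0.0;
    const float x = 0.3f, step = 1.0f; // a "grid" of integers, for illustration only
    const int trials = 100000;
    for(int i = 0; i < trials; ++i)
        sum += stochastic_round(x, step, rng);
    std::printf("mean of stochastically rounded values: %f (expected ~%f)\n", sum / trials, x);
}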
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp (new file, mode 0 → 100644)
This diff is collapsed on this page.