Project: gaoqiong / composable_kernel_ROCM
Commit dbb7002d — authored Feb 06, 2025 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/hotloop

Parents: 96c8d948, 2bef5501
Showing 20 of 228 changed files, with 145 additions and 177 deletions (+145 −177).
Changed files (additions / deletions):

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp   +2 −3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp    +3 −4
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp        +7 −3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp     +3 −2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp   +1 −2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp   +1 −2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp    +2 −3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp    +2 −3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                  +1 −2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                         +1 −2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                  +2 −3
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp                              +1 −0
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp        +1 −2
include/ck/tensor_operation/gpu/device/tensor_layout.hpp                                                 +2 −0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                                +3 −3
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                                       +2 −2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                                 +92 −132
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                              +6 −4
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp                      +1 −0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp   +12 −5
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp

@@ -60,8 +60,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

@@ -103,7 +102,7 @@ __global__ void
     compute_ptr_offset_of_batch.GetAPtrOffset(0);
     compute_ptr_offset_of_batch.GetBPtrOffset(0);
     compute_ptr_offset_of_batch.GetCPtrOffset(0);
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <index_t NDimSpatial,
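The same edit recurs in every kernel entry point below: the per-target guard that listed gfx908, gfx90a and gfx94 separately is collapsed into the single family macro __gfx9__. As a minimal sketch of the guard idiom (hypothetical kernel name and GridwiseGemm::Run call standing in for the real work; only GetSharedMemoryNumberOfByte and the ignore idiom appear in the actual diff):

// Sketch only: the body compiles for host passes and gfx9-family devices; on any other
// device target the kernel degenerates to a stub that merely consumes its argument.
template <typename GridwiseGemm, typename Karg>
__global__ void kernel_demo(const Karg karg) // hypothetical kernel
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    GridwiseGemm::Run(karg, p_shared); // hypothetical entry point for the real work
#else
    ignore = karg; // keep the parameter "used" so the stub still compiles cleanly
#endif
}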
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -55,8 +55,7 @@ __global__ void
     [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
     [[maybe_unused]] const index_t num_k_per_block)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

@@ -85,7 +84,7 @@ __global__ void
         k_idx);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename GridwiseGemm,

@@ -145,7 +144,7 @@ __global__ void
         k_idx);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <ck::index_t NDimSpatial,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -9,6 +9,7 @@
 #include <numeric>
 #include <sstream>

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -98,8 +99,7 @@ __global__ void
     const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
     const ComputePtrOffsetOfN compute_ptr_offset_of_n)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);

@@ -212,9 +212,13 @@ __global__ void
 }
 } // namespace

+#ifdef CK_CODE_GEN_RTC
+template <typename T>
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#else
 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());
+#endif

 //
 // @brief Device Convolution operation.
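The is_tuple alias above uses the usual expression-SFINAE detection idiom: the alias is only well-formed when T has an IsTuple() member, and the RTC branch merely swaps std::declval for ck::declval so the header also builds when the standard library is unavailable. A minimal, self-contained sketch of the same detection technique in plain C++ (the type names here are illustrative, not CK's):

#include <type_traits>
#include <utility>

// Detection idiom sketch: has_is_tuple<T> is true exactly when T exposes a member IsTuple().
template <typename T>
using is_tuple_expr = decltype(std::declval<T&>().IsTuple());

template <typename T, typename = void>
struct has_is_tuple : std::false_type {};

template <typename T>
struct has_is_tuple<T, std::void_t<is_tuple_expr<T>>> : std::true_type {};

struct TupleLike  { constexpr bool IsTuple() const { return true; } }; // illustrative type
struct PlainValue {};                                                  // illustrative type

static_assert(has_is_tuple<TupleLike>::value, "detected");
static_assert(!has_is_tuple<PlainValue>::value, "not detected");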
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

@@ -9,6 +9,7 @@
 #include <numeric>
 #include <sstream>

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -117,7 +118,7 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename GridwiseGemm,

@@ -183,7 +184,7 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }
 } // namespace
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -155,8 +155,7 @@ __global__ void
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
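Several of these grouped kernels split a flat 1-D grid evenly across batches: the number of blocks per batch is the grid size divided by the batch count, and a block's batch index is its block id divided by that quotient. A small host-side C++ sketch of the same arithmetic (plain integers instead of the wave-uniform __builtin_amdgcn_readfirstlane reads; the concrete sizes are illustrative):

#include <cassert>
#include <cstdio>

// Host-side sketch of the flat-grid -> (batch, local block) mapping used by the grouped kernels.
// Assumes grid_size is an exact multiple of batch_count, as the launch code is expected to ensure.
int main()
{
    const int grid_size   = 96; // total blocks launched (illustrative)
    const int batch_count = 8;  // number of groups/batches (illustrative)
    assert(grid_size % batch_count == 0);

    const int num_blocks_per_batch = grid_size / batch_count;

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx     = block_id / num_blocks_per_batch; // which batch this block works on
        const int local_blk = block_id % num_blocks_per_batch; // block index within that batch
        if(block_id < 3 || block_id == grid_size - 1)
            std::printf("block %2d -> batch %d, local block %d\n", block_id, g_idx, local_blk);
    }
    return 0;
}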
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

@@ -52,8 +52,7 @@ __global__ void
     const ComputePtrOffset compute_ptr_offset_of_groups,
     const ComputePtrOffset compute_ptr_offset_of_n)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp

@@ -68,8 +68,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

@@ -404,7 +403,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -43,8 +43,7 @@ __global__ void
     const B1ElementwiseOperation b1_element_op,
     const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -109,7 +108,7 @@ __global__ void
     ignore = acc_element_op;
     ignore = b1_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 // Computes C = A * B0 * B1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

@@ -38,8 +38,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

@@ -50,8 +50,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp

@@ -40,8 +40,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

@@ -80,7 +79,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp

@@ -3,6 +3,7 @@
 #pragma once

+#include "ck/library/utility/numeric.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp

@@ -56,8 +56,7 @@ __global__ void
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
     const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
include/ck/tensor_operation/gpu/device/tensor_layout.hpp

@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
 } // namespace convolution

+#ifndef CK_CODE_GEN_RTC
 template <typename Layout,
           typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>

@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
     os << Layout::name;
     return os;
 }
+#endif

 } // namespace tensor_layout
 } // namespace ck
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -340,8 +340,8 @@ struct Bilinear
 };

 template <>
-__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
-    std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
+__host__ __device__ constexpr void
+operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
 {
     y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
                              beta_ * type_convert<float>(x1));
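The Bilinear specialization above combines an int32_t accumulator with an int8_t residual as y = alpha * x0 + beta * x1, computed in float and converted back to int8_t. A small CPU-side sketch of the same arithmetic, using std::clamp plus round-to-nearest as a stand-in for whatever rounding and saturation ck::type_convert applies (that narrowing behavior is an assumption here, not a quote of CK):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference sketch of the Bilinear int8 path: y = alpha * x0 + beta * x1 in float,
// then narrowed to int8. Rounding and clamping below are illustrative assumptions.
std::int8_t bilinear_i8(float alpha, float beta, std::int32_t x0, std::int8_t x1)
{
    const float acc     = alpha * static_cast<float>(x0) + beta * static_cast<float>(x1);
    const float clamped = std::clamp(acc, -128.0f, 127.0f);
    return static_cast<std::int8_t>(std::lround(clamped));
}

int main()
{
    // e.g. alpha = 0.05 rescales a wide int32 GEMM accumulator, beta = 1.0 keeps the residual.
    std::printf("%d\n", bilinear_i8(0.05f, 1.0f, 2000, -7)); // prints 93
    return 0;
}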
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -533,7 +533,7 @@ struct NormalizeInInfer
                                            const T3& gamma,
                                            const T4& beta) const
     {
-        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
+        static_assert(is_same<T2, float>::value || is_same<T2, double>::value,
                       "Data type is not supported by this operation!");

         using ck::type_convert;
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -16,7 +16,8 @@ namespace ck {
 // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
 // (https://arxiv.org/abs/2211.10017) and implementation:
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__host__ __device__ inline half4_t pki4_to_half4(int q)
+// Convert lower part of packed int4 -> int4 to half
+__device__ inline half4_t i4_to_half4(int q)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;

@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
+__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;

@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
-{
-#if 1
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
-
-    const int EX  = 0x64006400;
-    const int SUB = 0xE408E408; //-8
-
-    int lo = i4s | EX;
-
-    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
-#else
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    vector_type<half_t, 2> res;
-
-    half_t x_h = (x_u8 & 0x0f) - 8;
-    half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
-
-    res.template AsType<half_t>()(Number<0>{}) = x_l;
-    res.template AsType<half_t>()(Number<1>{}) = x_h;
-
-    return res.template AsType<half2_t>()[Number<0>{}];
-#endif
-}
-
-__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
+__device__ inline bhalf4_t i4_to_bhalf4(int q)
 {
     uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
     return res.template AsType<bhalf4_t>()[Number<0>{}];
 }

-__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
-{
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-
-    float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
-    float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
-
-    vector_type<bhalf_t, 2> res;
-
-    res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
-    res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
-
-    return res.template AsType<bhalf2_t>()[Number<0>{}];
-}
-
 namespace tensor_operation {
 namespace element_wise {
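These helpers decode packed signed int4 data stored offset-by-8: nibble n encodes the value n - 8, as the deleted scalar fallbacks make explicit. The fp16 fast path relies on an exponent trick: OR a nibble into the mantissa of half(1024) (0x6400) to get 1024 + n, then add half(-1032) (0xE408) to land on n - 8, which is what the EX and SUB constants above encode twice over for a packed pair. A small self-contained CPU reference of the plain decode (it mirrors only the scalar fallback, not the in-register nibble shuffling of the GPU path):

#include <cstdint>
#include <cstdio>

// CPU reference for the "offset binary" int4 decode used by the i4_to_* helpers:
// each 4-bit nibble n in [0, 15] represents the signed value n - 8 in [-8, 7].
float decode_i4_nibble(std::uint8_t packed, bool high_nibble)
{
    const std::uint8_t n = high_nibble ? (packed >> 4) & 0x0F : packed & 0x0F;
    return static_cast<float>(n) - 8.0f;
}

int main()
{
    const std::uint8_t packed = 0x3B; // high nibble 0x3 -> -5, low nibble 0xB -> 3
    std::printf("low = %.0f, high = %.0f\n",
                decode_i4_nibble(packed, false), decode_i4_nibble(packed, true));
    return 0;
}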
@@ -159,11 +118,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
-        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else

@@ -171,13 +130,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

-        dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<half2_t>()(Number<0>{}) =
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
 (the Number<1>, Number<2> and Number<3> lanes change the same way)

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif

@@ -185,11 +144,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<bhalf_t, 8> result;

-        result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
-        result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
+        result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
+        result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);

         y = result.template AsType<bhalf8_t>()[Number<0>{}];
 #else

@@ -197,13 +156,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

-        dst.template AsType<bhalf2_t>()(Number<0>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<bhalf2_t>()(Number<0>{}) =
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
 (the Number<1>, Number<2> and Number<3> lanes change the same way)

         y = dst.template AsType<bhalf8_t>()[Number<0>{}];
 #endif

@@ -219,12 +178,12 @@ struct DequantPack8
     __host__ __device__ constexpr void
     operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
-        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
+        result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else

@@ -232,13 +191,13 @@ struct DequantPack8
         vector_type<pk_i4_t, 4> src{x};

-        dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<half2_t>()(Number<0>{}) =
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
 (the Number<1>, Number<2> and Number<3> lanes change the same way)

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif
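PassThroughPack8 and DequantPack8 widen eight packed int4 weights to eight half or bf16 values per call; the dequant variant additionally multiplies by a per-pair scale. A small CPU reference of that arithmetic (the nibble ordering here is simple lowest-nibble-first and is an assumption; the CK_USE_PK4_LAYOUT_SHUFFLE path reorders lanes in-register and is not modeled):

#include <array>
#include <cstdint>
#include <cstdio>

// Reference sketch: expand a 32-bit word holding eight 4-bit codes into eight floats,
// value = (nibble - 8) * scale, i.e. the DequantPack8-style path with scale = 1 for
// plain pass-through. Nibble order is illustrative, not CK's shuffled layout.
std::array<float, 8> unpack8_i4(std::uint32_t q, float scale = 1.0f)
{
    std::array<float, 8> out{};
    for(int i = 0; i < 8; ++i)
        out[i] = (static_cast<float>((q >> (4 * i)) & 0xF) - 8.0f) * scale;
    return out;
}

int main()
{
    const auto w = unpack8_i4(0x76543210u, 0.5f); // codes 0..7 -> values -4.0 .. -0.5
    for(float v : w)
        std::printf("%.1f ", v);
    std::printf("\n");
    return 0;
}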
@@ -252,7 +211,7 @@ struct PassThroughPack2
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

-    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
         y      = type_convert<half2_t>(t);

@@ -260,7 +219,7 @@ struct PassThroughPack2
     __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
         uint8_t x_l  = (x_u8 & 0x0f) >> 0;
         uint8_t x_h  = (x_u8 & 0xf0) >> 4;

@@ -479,7 +438,7 @@ struct PassThrough
     template <>
     __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
     {
-        y = ck::type_convert<bf8_t>(x);
+        y = type_convert<bf8_t>(x);
     }
 };
@@ -552,21 +511,21 @@ struct Scale
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_);
+        y = type_convert<Y>(type_convert<float>(x) * scale_);
     }

     template <>
     __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::type_convert<half_t>(scale_) * x;
+        y = type_convert<half_t>(scale_) * x;
     };

     template <>
     __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
     {
-        const float x_tmp = ck::type_convert<float>(x);
+        const float x_tmp = type_convert<float>(x);
         const float y_tmp = scale_ * x_tmp;
-        y                 = ck::type_convert<bhalf_t>(y_tmp);
+        y                 = type_convert<bhalf_t>(y_tmp);
     };

     template <>

@@ -584,7 +543,7 @@ struct Scale
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
-        y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x));
+        y = type_convert<int8_t>(scale_ * type_convert<float>(x));
     };

     float scale_;

@@ -600,7 +559,7 @@ struct ScaleAndResetNaNToMinusInfinity
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
+        y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
     };

     float scale_;
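Scale is the simplest of these element-wise functors: it multiplies each input by a runtime factor, routing the narrow types through float exactly as the specializations above do. A minimal stand-alone C++ sketch of the same functor shape and how an epilogue might apply it element by element (the names are illustrative, not CK's):

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative stand-in for an element-wise "Scale" functor: y = scale * x,
// computed through float as the narrow-type specializations above do.
struct ScaleRef
{
    explicit ScaleRef(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(static_cast<float>(x) * scale_);
    }

    float scale_;
};

int main()
{
    const ScaleRef op{0.25f};
    std::vector<float> in{4.0f, -8.0f, 2.0f}, out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
        op(out[i], in[i]); // epilogue-style per-element application
    std::printf("%g %g %g\n", out[0], out[1], out[2]); // 1 -2 0.5
    return 0;
}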
@@ -671,12 +630,13 @@ struct UnaryAbs
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                           is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                           is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::abs(x);
+        y = math::abs(x);
     };

     template <>

@@ -694,7 +654,7 @@ struct UnarySqrt
         static_assert(is_same<T, float>::value || is_same<T, double>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sqrt(x);
+        y = math::sqrt(x);
     };
 };

@@ -713,9 +673,9 @@ struct Relu
     template <>
     __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
-        float x_f32 = ck::type_convert<float>(x);
+        float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
-        y           = ck::type_convert<bhalf_t>(y_f32);
+        y           = type_convert<bhalf_t>(y_f32);
     }
 };
@@ -731,7 +691,7 @@ struct FastGelu
     template <typename Y, typename X>
     __device__ void operator()(Y& y, const X& x) const;

+#ifndef CK_CODE_GEN_RTC
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {

@@ -742,6 +702,7 @@ struct FastGelu
         const float emu = exp(u);
         y               = x / (1.f + emu);
     }
+#endif

     // device code, use lower precision "__ocml_exp_f32" and "rcp"
     template <>

@@ -753,7 +714,7 @@ struct FastGelu
         const float u   = x * (c1 * x * x + c2);
         const float emu = __ocml_exp_f32(u);
-        y               = x * ck::math::rcp(1.f + emu);
+        y               = x * math::rcp(1.f + emu);
     }

     template <>
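FastGelu evaluates the tanh-style GELU approximation in sigmoid form: with u = -2·sqrt(2/pi)·(x + 0.044715·x^3), the expression x / (1 + e^u) above equals 0.5·x·(1 + tanh(sqrt(2/pi)·(x + 0.044715·x^3))). CK's c1 and c2 constants are defined outside the hunks shown here, so the coefficients in this sketch are the conventional ones, not a quote of the file:

#include <cmath>
#include <cstdio>

// Sketch of the tanh-based fast-GELU approximation in the same "x / (1 + e^u)" form as
// the FastGelu functor above. Coefficients are the conventional 0.044715 and sqrt(2/pi).
float fast_gelu_ref(float x)
{
    const float k = 0.7978845608f * (x + 0.044715f * x * x * x); // sqrt(2/pi) * (x + 0.044715 x^3)
    const float u = -2.0f * k;
    return x / (1.0f + std::exp(u)); // == 0.5 * x * (1 + tanh(k))
}

int main()
{
    for(float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
        std::printf("x=% .1f  fast_gelu=% .6f\n", x, fast_gelu_ref(x));
    return 0;
}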
@@ -851,10 +812,9 @@ struct Gelu
     }

     template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
-                                                                 const ck::half_t& x) const
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
+        y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
     }
 };
@@ -868,7 +828,7 @@ struct Sigmoid
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y               = one / (one + ck::math::exp(-x));
+        y               = one / (one + math::exp(-x));
     };
 };

@@ -877,11 +837,11 @@ struct Silu
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
+        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, half_t> ||
                           is_same_v<T, int8_t> || is_same_v<T, int32_t>,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y               = x * (one / (one + ck::math::exp(-x)));
+        y               = x * (one / (one + math::exp(-x)));
     };
 };
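Sigmoid and Silu differ only by the extra multiplication by x: sigmoid(x) = 1 / (1 + e^-x) and silu(x) = x * sigmoid(x). A tiny float reference to make the relationship concrete:

#include <cmath>
#include <cstdio>

// Float reference for the two activations defined above.
float sigmoid_ref(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float silu_ref(float x)    { return x * sigmoid_ref(x); } // SiLU == x * sigmoid(x)

int main()
{
    const float x = 1.5f;
    std::printf("sigmoid(%.1f) = %.6f, silu(%.1f) = %.6f\n", x, sigmoid_ref(x), x, silu_ref(x));
    return 0;
}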
@@ -895,7 +855,7 @@ struct TanH
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::tanh(x);
+        y = math::tanh(x);
     };
 };

The same two-part change then repeats, hunk by hunk, through the remaining unary functors ACos, Neg, ATan, Sin, ASinH, Cos, ACosH, Tan, ATanH, SinH, Ceil, Exp, CosH, Floor, Log, ASin and Rcp: in each static_assert, is_same<T, ck::half_t> becomes is_same<T, half_t>, and in each body the ck:: qualifier is dropped from the math call (for example, y = ck::math::acos(x); becomes y = math::acos(x);). In Cos the call is shortened further to y = cos(x);.
@@ -1153,7 +1113,7 @@ struct Swish
...
@@ -1153,7 +1113,7 @@ struct Swish
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
math
::
exp
(
bx
)));
};
};
const
float
beta_
;
const
float
beta_
;
...
@@ -1172,7 +1132,7 @@ struct SoftRelu
...
@@ -1172,7 +1132,7 @@ struct SoftRelu
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
y
=
math
::
log
(
one
+
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
}
const
float
alpha_
;
const
float
alpha_
;
};
};
...
@@ -1193,7 +1153,7 @@ struct Power
...
@@ -1193,7 +1153,7 @@ struct Power
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
y
=
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
}
const
float
alpha_
;
const
float
alpha_
;
const
float
beta_
;
const
float
beta_
;
...
@@ -1213,7 +1173,7 @@ struct ClippedRelu
...
@@ -1213,7 +1173,7 @@ struct ClippedRelu
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
y
=
math
::
min
(
casted_beta
,
math
::
max
(
casted_alpha
,
x
));
}
}
const
float
alpha_
;
const
float
alpha_
;
const
float
beta_
;
const
float
beta_
;
...
@@ -1248,7 +1208,7 @@ struct Elu
                          is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
-        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+        y              = x > 0 ? x : casted_alpha * math::expm1(x);
    }
    const float alpha_;
};
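Elu keeps positive inputs unchanged and maps negative ones to alpha * (exp(x) - 1); expm1 is used for accuracy near zero. A small host-side sketch (plain C++, not part of this commit; `elu_ref` is an illustrative name):

#include <cmath>
#include <cstdio>

// Reference Elu: y = x for x > 0, otherwise alpha * (exp(x) - 1), computed via expm1.
float elu_ref(float x, float alpha = 1.0f)
{
    return x > 0.0f ? x : alpha * std::expm1(x);
}

int main()
{
    std::printf("elu(-1.0) = %f\n", elu_ref(-1.0f)); // ~ -0.6321
    return 0;
}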
...
@@ -1350,10 +1310,10 @@ struct FastNumericArrayConverter
};

template <>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
+struct FastNumericArrayConverter<uint8_t, half_t, 4>
{
    using InputArray  = vector_type<uint8_t, 4>;
-    using OutputArray = vector_type<ck::half_t, 4>;
+    using OutputArray = vector_type<half_t, 4>;

    __device__ static OutputArray convert(InputArray const& Input)
    {
...
@@ -1383,13 +1343,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
};

template <index_t N>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
+struct FastNumericArrayConverter<uint8_t, half_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using InputArray  = vector_type<uint8_t, N>;
-    using OutputArray = vector_type<ck::half_t, N>;
+    using OutputArray = vector_type<half_t, N>;

    __device__ static OutputArray convert(InputArray const& Input)
    {
...
@@ -1398,7 +1358,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
        OutputArray Output;

        using Vec_InputArray  = vector_type<uint8_t, 4>;
-        using Vec_OutputArray = vector_type<ck::half_t, 4>;
+        using Vec_OutputArray = vector_type<half_t, 4>;

        Vec_OutputArray* half_4_ptr       = reinterpret_cast<Vec_OutputArray*>(&Output);
        Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
...
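The N-wide specialization above is only a wrapper: N must be a multiple of 4, and the input is reinterpreted as chunks of 4 elements that are fed through the 4-wide converter. A rough host-side sketch of that chunking structure (plain C++, not part of this commit; float stands in for half_t and a trivial per-element cast replaces the packed bit-manipulation; `convert4`/`convertN` are illustrative names):

#include <array>
#include <cstdint>
#include <cstdio>

// Stand-in for the 4-wide converter: handles exactly one fixed-size chunk.
std::array<float, 4> convert4(const std::array<uint8_t, 4>& in)
{
    std::array<float, 4> out{};
    for(int i = 0; i < 4; ++i)
        out[i] = static_cast<float>(in[i]); // the real device code uses packed conversion tricks
    return out;
}

// Stand-in for the N-wide converter: N must be a multiple of 4; chunks are delegated to convert4.
template <int N>
std::array<float, N> convertN(const std::array<uint8_t, N>& in)
{
    static_assert(N % 4 == 0, "N must be multiple of 4.");
    std::array<float, N> out{};
    for(int chunk = 0; chunk < N / 4; ++chunk)
    {
        std::array<uint8_t, 4> in4{};
        for(int i = 0; i < 4; ++i)
            in4[i] = in[chunk * 4 + i];
        const auto out4 = convert4(in4);
        for(int i = 0; i < 4; ++i)
            out[chunk * 4 + i] = out4[i];
    }
    return out;
}

int main()
{
    const std::array<uint8_t, 8> in{0, 1, 2, 3, 4, 5, 6, 7};
    const auto out = convertN<8>(in);
    std::printf("out[7] = %f\n", out[7]); // 7.0
    return 0;
}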
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @ dbb7002d
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

+#ifndef CK_CODE_GEN_RTC
#include <limits>
#include <stdlib.h>
+#endif
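The new guard keeps standard-library headers out of builds where they may not be on the include path; the assumption here is that CK_CODE_GEN_RTC marks translation units compiled at runtime. A minimal sketch of that guard pattern (plain C++, not part of this commit; `max_int` is an illustrative name):

// Host-only headers are skipped when the guard macro is defined by the build.
#ifndef CK_CODE_GEN_RTC
#include <limits>
#include <cstdio>
#endif

int max_int()
{
#ifndef CK_CODE_GEN_RTC
    return std::numeric_limits<int>::max(); // normal host build: <limits> is available
#else
    return 2147483647;                      // guarded build: avoid <limits>, use a plain constant
#endif
}

int main()
{
#ifndef CK_CODE_GEN_RTC
    std::printf("max_int() = %d\n", max_int());
#endif
    return 0;
}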
namespace ck {
...
@@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
        // Create 3D grid
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-        return std::make_tuple(N0, M0, k_split);
+        return make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
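The grid shape above is a ceiling division of the problem size by the tile size in each dimension. A quick host-side check of that arithmetic (plain C++, not part of this commit; `integer_divide_ceil` below is a local stand-in for ck::math::integer_divide_ceil and the sizes are made up):

#include <cstdio>

// Stand-in for ck::math::integer_divide_ceil.
int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int M = 1000, N = 768, MPerBlock = 256, NPerBlock = 128, k_split = 2;
    const int M0 = integer_divide_ceil(M, MPerBlock); // 4 tiles along M
    const int N0 = integer_divide_ceil(N, NPerBlock); // 6 tiles along N
    std::printf("grid = (%d, %d, %d)\n", N0, M0, k_split); // matches make_tuple(N0, M0, k_split)
    return 0;
}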
...
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
        uint32_t dp_for_sk_iters = k_iters_per_tile.get();
-        uint32_t best_sk_score   = std::numeric_limits<int>::max();
+        uint32_t best_sk_score   = NumericLimits<int32_t>::Max();
        // we need to find the smallest sk iters
        for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
            tentative_sk_blocks++)
        {
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @ dbb7002d
...
@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        constexpr index_t Gemm1KPack =
            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
View file @ dbb7002d
...
@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
            static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
-                      B1K1),
-            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // therefore we may just as well assign Gemm1KPack = group_size
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;

        auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
            BlockSize,
...
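To make the comment above concrete: the old formula max(lcm(group_size, B1K1), k_per_blk) can exceed group_size (for example group_size = 4, B1K1 = 8 gives 8), which is exactly the A1K1 mismatch the comment warns about, so the new code pins Gemm1KPack to group_size. A small sketch of that arithmetic (plain C++17, not part of this commit; the numbers are illustrative, not taken from any specific MFMA instruction):

#include <algorithm>
#include <cstdio>
#include <numeric>

int main()
{
    const int group_size = 4; // illustrative MfmaSelector::selected_mfma.group_size (A1K1)
    const int k_per_blk  = 4; // illustrative MfmaSelector::selected_mfma.k_per_blk
    const int B1K1       = 8; // illustrative B1 K-pack

    const int old_gemm1_kpack = std::max(std::lcm(group_size, B1K1), k_per_blk); // 8 > group_size
    const int new_gemm1_kpack = group_size;                                      // 4, never exceeds A1K1

    std::printf("old = %d, new = %d\n", old_gemm1_kpack, new_gemm1_kpack);
    return 0;
}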