gaoqiong / composable_kernel

Commit c16f789d, authored Apr 18, 2022 by rocking

    Merge remote-tracking branch 'origin/develop' into gemm_softmax

Parents: 21802fda, 4221505d

Showing 17 changed files with 232 additions and 77 deletions (+232 -77):
  include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp   +20  -0
  include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp                   +15  -0
  include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp   +18  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp        +18  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp               +13  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                   +21  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp                   +13  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                 +13  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp                   +13  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp                   +15  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                   +17  -0
  include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp                   +1   -0
  include/ck/utility/common_header.hpp                                                 +2   -0
  include/ck/utility/data_type.hpp                                                     +0   -71
  include/ck/utility/dynamic_buffer.hpp                                                +8   -5
  include/ck/utility/generic_memory_space_atomic_add.hpp (new file)                    +44  -0
  script/cmake-rocm.sh                                                                 +1   -1
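Nearly every change in this commit follows one pattern: the body of each __global__ kernel is wrapped in a preprocessor guard so that device code is generated only for the MFMA-capable targets gfx908 and gfx90a (the host compilation pass, where __HIP_DEVICE_COMPILE__ is undefined, also takes the real branch), and on all other targets every kernel parameter is assigned to `ignore` so the empty body compiles without unused-parameter warnings. A minimal sketch of the pattern, with a hypothetical kernel and argument (only the guard and the ignore idiom are taken from the diff):

    // Hypothetical kernel illustrating the guard added throughout this commit.
    template <typename GridwiseOp, typename Arg>
    __global__ void kernel_example(const Arg arg)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
        // Real work: emitted for the host pass and for gfx908/gfx90a device passes.
        GridwiseOp::Run(arg);
    #else
        // Other device targets get an empty kernel; discard the parameter so
        // -Wunused-parameter stays quiet.
        ignore = arg;
    #endif
    }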
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp

@@ -54,6 +54,7 @@ __global__ void
         const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

@@ -88,6 +89,25 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         d_grid_desc_mblock_mperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_d0_grid;
+    ignore = p_d1_grid;
+    ignore = batch_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = d0_reduce_op;
+    ignore = d1_reduce_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = d_grid_desc_mblock_mperblock;
+    ignore = compute_base_ptr_of_batch_;
+    ignore = block_2_ctile_map;
+#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <typename ALayout,
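Both batched device files compute the batch index from the flat 1-D block id. Wrapping the divisions in __builtin_amdgcn_readfirstlane asserts to the compiler that the result is uniform across the wavefront, letting it live in a scalar register. The arithmetic itself in a plain C++ sketch (get_grid_size() and get_block_1d_id() are CK wrappers over the grid and block ids; the readfirstlane intrinsic is elided since only the mapping is shown):

    #include <cassert>

    struct BlockCoord
    {
        int batch;       // which GEMM of the batch this block works on (g_idx above)
        int local_block; // tile index within that GEMM
    };

    BlockCoord map_block(int block_id, int grid_size, int batch_count)
    {
        assert(grid_size % batch_count == 0); // grid is split into equal per-batch slices
        const int num_blocks_per_batch = grid_size / batch_count;
        return {block_id / num_blocks_per_batch,
                block_id % num_blocks_per_batch};
    }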
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp

@@ -46,6 +46,7 @@ __global__ void
         const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

@@ -70,6 +71,20 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = compute_base_ptr_of_batch_;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <typename ADataType,
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp

@@ -49,6 +49,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

@@ -73,6 +74,23 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = num_batches;
+    ignore = a_batch_stride;
+    ignore = b_batch_stride;
+    ignore = c_batch_stride;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 // specialization for 3D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

@@ -48,6 +48,7 @@ __global__ void
         const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,

@@ -66,6 +67,23 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         d_grid_desc_mblock_mperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_d0_grid;
+    ignore = p_d1_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = d0_reduce_op;
+    ignore = d1_reduce_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = d_grid_desc_mblock_mperblock;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <typename FloatAB,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

@@ -38,6 +38,7 @@ __global__ void
             c_grid_desc_mblock_mperblock_nblock_nperblock,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,

@@ -51,6 +52,18 @@ __global__ void
         b_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <typename FloatAB,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

@@ -39,6 +39,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,

@@ -52,6 +53,18 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <typename GridwiseGemm,

@@ -74,6 +87,7 @@ __global__ void
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -126,6 +140,13 @@ __global__ void
         gemm_desc_ptr[group_id].block_2_ctile_map_,
         block_id_grp);
 #endif
+#else
+    ignore = gemm_desc_;
+    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <index_t BlockSize,
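The second kernel in this file is the grouped-GEMM entry point, which first resolves the flat block id into a group id and a group-local block id (block_id_grp) before invoking the per-group block_2_ctile_map_. The resolution itself falls outside the captured hunks; a hedged sketch of the usual approach, assuming each GEMM descriptor carries the half-open block-id range it owns (the GemmDescSketch type and its BlockStart_/BlockEnd_ fields are illustrative, not from this diff):

    struct GemmDescSketch
    {
        int BlockStart_; // first flat block id owned by this group (assumed field)
        int BlockEnd_;   // one past the last owned block id (assumed field)
    };

    __device__ void find_group(const GemmDescSketch* descs,
                               int group_count,
                               int block_id,
                               int& group_id,
                               int& block_id_grp)
    {
        for(int g = 0; g < group_count; ++g)
        {
            if(block_id >= descs[g].BlockStart_ && block_id < descs[g].BlockEnd_)
            {
                group_id     = g;
                block_id_grp = block_id - descs[g].BlockStart_;
                return;
            }
        }
    }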
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp

@@ -37,6 +37,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

@@ -53,6 +54,18 @@ __global__ void
         b_element_op,
         c_element_op,
         c_block_cluster_adaptor);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = c_block_cluster_adaptor;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <index_t BlockSize,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -39,6 +39,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

@@ -55,6 +56,18 @@ __global__ void
         b_element_op,
         c_element_op,
         c_block_cluster_adaptor);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = c_block_cluster_adaptor;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <index_t BlockSize,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

@@ -42,6 +42,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(

@@ -56,6 +57,18 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -45,6 +45,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(

@@ -61,6 +62,20 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_c0_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -49,6 +49,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(

@@ -67,6 +68,22 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_c0_grid;
+    ignore = p_c1_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

 template <
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp

@@ -36,6 +36,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
         DataType value)
 {
     using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;

     constexpr auto I0 = Number<0>{};
 ...
include/ck/utility/common_header.hpp

@@ -13,6 +13,7 @@
 #include "functional3.hpp"
 #include "functional4.hpp"
 #include "enable_if.hpp"
+#include "ignore.hpp"
 #include "integral_constant.hpp"
 #include "math.hpp"
 #include "number.hpp"

@@ -30,6 +31,7 @@
 #include "debug.hpp"
 #include "amd_buffer_addressing.hpp"
+#include "generic_memory_space_atomic_add.hpp"
 #include "get_id.hpp"
 #include "synchronization.hpp"
 #include "amd_address_space.hpp"
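The new "ignore.hpp" include supplies the `ignore` object that the guarded kernels above assign their unused parameters to. Assuming it mirrors std::ignore, restated for device code where pulling in <tuple> is undesirable, a minimal sketch of such a utility (the actual contents of ignore.hpp are not shown in this commit):

    namespace ck {

    // A write-only sink: assignment from any type compiles to nothing.
    struct ignore_t
    {
        template <typename T>
        __host__ __device__ constexpr const ignore_t& operator=(const T&) const
        {
            return *this;
        }
    };

    inline constexpr ignore_t ignore{};

    } // namespace ck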
include/ck/utility/data_type.hpp

@@ -992,77 +992,6 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, float>(float x)
     return uint16_t(u.int32 >> 16);
 }

-// TODO: deprecate this
-template <typename T>
-struct inner_product_with_conversion
-{
-    template <typename X, index_t N>
-    __device__ T operator()(typename vector_type<X, N>::type a,
-                            typename vector_type<X, N>::type b) const
-    {
-        const vector_type<X, N> a_vector{a};
-        const vector_type<X, N> b_vector{b};
-
-        T acc = 0;
-
-        static_for<0, N, 1>{}([&](auto i) {
-            acc += type_convert<T>(a_vector.Scalars()[i]) * type_convert<T>(b_vector.Scalars()[i]);
-        });
-
-        return acc;
-    }
-
-    __device__ T operator()(float_t a, float_t b) const
-    {
-        return type_convert<T>(a) * type_convert<T>(b);
-    }
-
-    __device__ T operator()(int8x4_t a, int8x4_t b) const
-    {
-        const vector_type<int8_t, 4> a_vector{a};
-        const vector_type<int8_t, 4> b_vector{b};
-
-        T acc = 0;
-
-        static_for<0, 4, 1>{}([&](auto i) {
-            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
-                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
-        });
-
-        return acc;
-    }
-
-    __device__ T operator()(int8x8_t a, int8x8_t b) const
-    {
-        const vector_type<int8_t, 8> a_vector{a};
-        const vector_type<int8_t, 8> b_vector{b};
-
-        T acc = 0;
-
-        static_for<0, 8, 1>{}([&](auto i) {
-            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
-                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
-        });
-
-        return acc;
-    }
-
-    __device__ T operator()(int8x16_t a, int8x16_t b) const
-    {
-        const vector_type<int8_t, 16> a_vector{a};
-        const vector_type<int8_t, 16> b_vector{b};
-
-        T acc = 0;
-
-        static_for<0, 16, 1>{}([&](auto i) {
-            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
-                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
-        });
-
-        return acc;
-    }
-};
-
 template <typename T>
 struct NumericLimits
 {
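The removed functor, already tagged "TODO: deprecate this", computed an ordinary inner product with both operands converted to the accumulator type before multiplying. For reference, the scalar essence of what it did, as a standalone sketch:

    #include <cstdint>

    // Convert each element to the accumulator type T, multiply, accumulate:
    // the core of the removed inner_product_with_conversion functor.
    template <typename T, typename X, int N>
    T inner_product_with_conversion(const X (&a)[N], const X (&b)[N])
    {
        T acc = 0;
        for(int i = 0; i < N; ++i)
            acc += static_cast<T>(a[i]) * static_cast<T>(b[i]);
        return acc;
    }

    // Example matching the removed int8x4_t overload:
    //   int8_t a[4]{1, 2, 3, 4}, b[4]{5, 6, 7, 8};
    //   int32_t r = inner_product_with_conversion<int32_t>(a, b); // r == 70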
include/ck/utility/dynamic_buffer.hpp

 #pragma once

+#include "amd_buffer_addressing.hpp"
+#include "c_style_pointer_cast.hpp"
 #include "config.hpp"
 #include "enable_if.hpp"
-#include "c_style_pointer_cast.hpp"
-#include "amd_buffer_addressing.hpp"
+#include "generic_memory_space_atomic_add.hpp"

 namespace ck {

+// T may be scalar or vector
+// X may be scalar or vector
+// T and X have same scalar type
+// X contains multiple T
 template <AddressSpaceEnum BufferAddressSpace,
           typename T,
           typename ElementSpaceSize,

@@ -316,9 +321,7 @@ struct DynamicBuffer
         {
             if(is_valid_element)
             {
-                // FIXME: atomicAdd is defined by HIP, need to avoid implicit type casting when
-                // calling it
-                atomicAdd(c_style_pointer_cast<X*>(&p_data_[i]), x);
+                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
             }
         }
     }
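The deleted FIXME spells out the motivation: calling HIP's atomicAdd directly leaves overload resolution, including any implicit conversions, to the compiler, and no overload exists at all for vector types such as float2_t. The new ck::atomic_add<X> makes the element type explicit: the primary template is only declared, so any X without a specialization fails to build instead of silently converting. A sketch of the dispatch idea (trimmed from the new header below):

    // Declared but not defined: unsupported types cannot link.
    template <typename X>
    __device__ X atomic_add(X* p_dst, const X& x);

    // Only exact-type specializations forward to HIP's atomicAdd.
    template <>
    __device__ float atomic_add<float>(float* p_dst, const float& x)
    {
        return atomicAdd(p_dst, x);
    }

    // atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x) therefore compiles
    // only for types the library provides a specialization for.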
include/ck/utility/generic_memory_space_atomic_add.hpp (new file, mode 100644)

#pragma once
#include "data_type.hpp"

namespace ck {

template <typename X>
__device__ X atomic_add(X* p_dst, const X& x);

template <>
__device__ int32_t atomic_add<int32_t>(int32_t* p_dst, const int32_t& x)
{
    return atomicAdd(p_dst, x);
}

template <>
__device__ uint32_t atomic_add<uint32_t>(uint32_t* p_dst, const uint32_t& x)
{
    return atomicAdd(p_dst, x);
}

template <>
__device__ float atomic_add<float>(float* p_dst, const float& x)
{
    return atomicAdd(p_dst, x);
}

template <>
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const vector_type<float, 2> vx{x};
    vector_type<float, 2> vy{0};

    vy.template AsType<float>()(I0) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
    vy.template AsType<float>()(I1) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);

    return vy.template AsType<float2_t>()[I0];
}

} // namespace ck
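A usage sketch for the new header: each thread folds one float2_t into a shared accumulator. Note that the float2_t specialization is not one 64-bit atomic; it issues two independent 32-bit atomicAdds, so each component is updated atomically but the pair is not updated as a unit (the kernel below is illustrative, not part of this commit):

    #include "common_header.hpp" // pulls in atomic_add as of this commit

    __global__ void accumulate(ck::float2_t* p_acc, const ck::float2_t* p_vals, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
        {
            // Two scalar atomicAdds under the hood, one per component.
            ck::atomic_add<ck::float2_t>(p_acc, p_vals[i]);
        }
    }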
script/cmake-rocm.sh

@@ -10,7 +10,7 @@ cmake
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
 -D BUILD_DEV=OFF \
 -D CMAKE_BUILD_TYPE=Release \
--D CMAKE_CXX_FLAGS=" -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
+-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
 -D CMAKE_PREFIX_PATH=/opt/rocm \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \