Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d39c3f5d
Commit
d39c3f5d
authored
Jun 06, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
88b978c5
ac58cc5d
Changes
120
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1804 additions
and
61 deletions
+1804
-61
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
+12
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+1
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
.../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+1118
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+23
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+84
-21
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+3
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+71
-3
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
+175
-0
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+10
-1
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+7
-6
include/ck_tile/core/tensor/store_tile.hpp
include/ck_tile/core/tensor/store_tile.hpp
+1
-1
include/ck_tile/core/tensor/tensor_view.hpp
include/ck_tile/core/tensor/tensor_view.hpp
+26
-4
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+1
-0
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+60
-0
include/ck_tile/core/tensor/update_tile.hpp
include/ck_tile/core/tensor/update_tile.hpp
+55
-0
include/ck_tile/core/utility/philox_rand.hpp
include/ck_tile/core/utility/philox_rand.hpp
+89
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+32
-4
include/ck_tile/host/reference/reference_batched_dropout.hpp
include/ck_tile/host/reference/reference_batched_dropout.hpp
+33
-0
No files found.
example/ck_tile/01_fmha/script/smoke_test.sh
→
example/ck_tile/01_fmha/script/smoke_test
_fwd
.sh
View file @
d39c3f5d
...
@@ -17,17 +17,18 @@ for perm in 0 1 ; do
...
@@ -17,17 +17,18 @@ for perm in 0 1 ; do
for
vlayout
in
"r"
"c"
;
do
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias
-p_drop=$p_drop
-lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
done
done
...
@@ -37,6 +38,7 @@ done
...
@@ -37,6 +38,7 @@ done
done
done
done
done
done
done
done
for
perm
in
0 1
;
do
for
perm
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
bias
in
"n"
"e"
"a"
;
do
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
d39c3f5d
...
@@ -674,7 +674,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
...
@@ -674,7 +674,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
clear_workspace
();
clear_workspace
();
};
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
ave_time
+
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
stream_config
,
stream_config
,
run_flush_cache
,
run_flush_cache
,
kernel
,
kernel
,
...
@@ -690,7 +690,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
...
@@ -690,7 +690,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
else
else
{
{
ave_time
=
launch_and_time_kernel_with_preprocess
(
ave_time
+
=
launch_and_time_kernel_with_preprocess
(
stream_config
,
stream_config
,
clear_workspace
,
clear_workspace
,
kernel
,
kernel
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
d39c3f5d
...
@@ -820,15 +820,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -820,15 +820,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return
false
;
return
false
;
}
}
}
}
else
if
(
ck
::
is_lds_direct_load_supported
())
if
(
!
ck
::
is_xdl_supported
())
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
{
return
false
;
}
}
else
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
0 → 100644
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
,
const
AGridDesc_AK0_M_K1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
index_t
batch_count
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
gridDim
.
y
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
AGridDesc_AK0_M_K1
,
BGridDesc_BK0_N_K1
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_c_grid
+
e_batch_offset
,
p_shared
,
karg
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
typename
GridwiseGemm
,
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
,
const
AGridDesc_AK0_M_K1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
index_t
batch_count
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
gridDim
.
y
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run_2Lds
<
AGridDesc_AK0_M_K1
,
BGridDesc_BK0_N_K1
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_c_grid
+
e_batch_offset
,
p_shared_0
,
p_shared_1
,
karg
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
// namespace
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimentions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
typename
AComputeDataType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
()),
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename
BComputeDataType
=
AComputeDataType
>
struct
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
:
public
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
AComputeDataType
,
BComputeDataType
>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
;
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
bool
isMultiD
=
DsDataType
::
Size
()
>
0
;
static
constexpr
bool
isMultiABD
=
isMultiA
||
isMultiB
||
isMultiD
;
// multi ABD not supported
static_assert
(
!
isMultiABD
,
"Multi A, Mutli B and Multi D are not supported"
);
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
const
auto
M
=
in_gemmm_gemmk_desc
.
GetLength
(
I0
);
const
auto
K
=
in_gemmm_gemmk_desc
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
in_gemmm_gemmk_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
const
auto
N
=
wei_gemmn_gemmk_desc
.
GetLength
(
I0
);
const
auto
K
=
wei_gemmn_gemmk_desc
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
wei_gemmn_gemmk_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
return
out_gemmm_gemmn_desc
;
}
// desc for problem definition
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
AComputeDataType, BComputeDataType
// Use appropriate gridwise gemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
GridwiseGemmV3TemplateParams
>
;
static
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
const
index_t
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
index_t
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
return
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
,
GridwiseGemm
::
CalculateMBlock
(
M
),
GridwiseGemm
::
CalculateNBlock
(
N
));
}
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_as
,
const
void
*
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_a_grid_
{},
p_b_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_grid_desc_ak0_m_ak1_
{
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_bk0_n_bk1_
{
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
// p_as and p_bs are pointers
p_a_grid_
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_b_grid_
=
static_cast
<
const
BDataType
*>
(
p_bs
);
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
}
void
Print
()
const
{
std
::
cout
<<
"A[AK0, M, AK1]: "
<<
a_grid_desc_ak0_m_ak1_
<<
std
::
endl
;
std
::
cout
<<
"B[BK0, N, BK1]: "
<<
b_grid_desc_bk0_n_bk1_
<<
std
::
endl
;
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
}
// private:
// pointers (tuple if multi AB, pointer if no)
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
float
ave_time
=
0
;
constexpr
index_t
minimum_occupancy
=
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
const
index_t
GemmM
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
GemmN
=
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
GemmK
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
GemmM
,
GemmN
,
I1
/*arg.KBatch*/
);
gdy
*=
arg
.
num_group_
;
index_t
K_split
=
(
GemmK
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
typename
GridwiseGemm
::
Argument
gemm_arg
{
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_e_grid_
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
{
typename
GridwiseGemm
::
Argument
gemm_arg_
=
gemm_arg
;
ck
::
utility
::
RotatingMemWrapper
<
typename
GridwiseGemm
::
Argument
>
rotating_mem
(
gemm_arg_
,
stream_config
.
rotating_count
,
gemm_arg_
.
M
*
gemm_arg_
.
K
*
sizeof
(
ADataType
),
gemm_arg_
.
K
*
gemm_arg_
.
N
*
sizeof
(
BDataType
));
rotating_mem
.
Print
();
auto
run_flush_cache
=
[
&
]()
{
// flush icache
ck
::
utility
::
flush_icache
();
// rotating mem
rotating_mem
.
Next
();
};
ave_time
+=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
stream_config
,
run_flush_cache
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
gemm_arg_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
num_group_
);
}
else
{
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
gemm_arg
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
num_group_
);
}
};
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
const
auto
kernel
=
kernel_grouped_conv_fwd_xdl_cshuffle_v3
<
GridwiseGemm
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
if
(
get_device_name
()
==
"gfx908"
)
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
ConvStride
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
// check vector access of A
// FIXME: layout
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
G_NW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
{
const
index_t
C
=
arg
.
a_g_n_c_wis_lengths_
[
2
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of B
// FIXME: layout
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
G_K_X_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_YX_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_ZYX_C
>
||
is_same_v
<
BLayout
,
ctc
::
GKXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKYXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKZYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXGC
>
)
{
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
{
const
index_t
K
=
arg
.
e_g_n_k_wos_lengths_
[
2
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check Gridwise GEMM
const
index_t
GemmM
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
GemmN
=
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
GemmK
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
typename
GridwiseGemm
::
Argument
gemm_arg
{
nullptr
,
nullptr
,
nullptr
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
/*KBatch*/
};
return
GridwiseGemm
::
CheckValidity
(
gemm_arg
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_as
,
const
void
*
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
BlockGemmPipelineScheduler
,
std
::
string
>
BlkGemmPipelineSchedulerToString
{
{
BlockGemmPipelineScheduler
::
Intrawave
,
"Intrawave"
},
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
},
{
BlockGemmPipelineVersion
::
v4
,
"v4"
},
{
BlockGemmPipelineVersion
::
v5
,
"v5"
}};
// clang-format off
str
<<
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CDEBlockTransferScalarPerVector_NPerBlock
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
"BlkGemmPipelineScheduler: "
<<
BlkGemmPipelineSchedulerToString
[
BlkGemmPipeSched
]
<<
", "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
d39c3f5d
...
@@ -961,6 +961,29 @@ struct Elu
...
@@ -961,6 +961,29 @@ struct Elu
const
float
alpha_
;
const
float
alpha_
;
};
};
struct
ConvScale
{
__host__
__device__
ConvScale
(
float
scale_in
=
1.
f
,
float
scale_wei
=
1.
f
,
float
scale_out
=
1.
f
)
:
scale_in_
(
scale_in
),
scale_wei_
(
scale_wei
),
scale_out_
(
scale_out
)
{
}
template
<
typename
E
,
typename
C
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
e
,
const
float
&
c
)
const
{
e
=
type_convert
<
f8_t
>
(
c
*
scale_in_
*
scale_wei_
*
scale_out_
);
};
float
scale_in_
;
float
scale_wei_
;
float
scale_out_
;
};
// support fastconvert of int8 to fp16
// support fastconvert of int8 to fp16
template
<
typename
InputDataType
,
typename
OutputDataType
,
index_t
RegPackNumber
>
template
<
typename
InputDataType
,
typename
OutputDataType
,
index_t
RegPackNumber
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
d39c3f5d
...
@@ -1123,7 +1123,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1123,7 +1123,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
template
<
typename
CGridDesc
>
template
<
typename
CGridDesc
>
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_grid_desc_m_n
,
index_t
MBlock
,
index_t
NBlock
)
const
CGridDesc
&
c_grid_desc_m_n
,
index_t
MBlock
,
index_t
NBlock
)
{
{
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
...
@@ -1141,26 +1141,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1141,26 +1141,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
using
Block2CTileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
Block2CTileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template
<
bool
HasMainKBlockLoop
,
template
<
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run
(
const
ADataType
*
p_a_grid
,
__device__
static
void
Run
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared
,
void
*
p_shared
,
const
Problem
&
problem
)
const
Problem
&
problem
,
const
AGridDesc_AK0_M_K1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
)
{
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -1508,12 +1504,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1508,12 +1504,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
__device__
static
void
Run
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared
,
void
*
p_shared_1
,
const
Problem
&
problem
)
const
Problem
&
problem
)
{
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
...
@@ -1521,11 +1516,42 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1521,11 +1516,42 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
Run
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
problem
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
}
template
<
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
const
Problem
&
problem
,
const
AGridDesc_AK0_M_K1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -1879,6 +1905,43 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1879,6 +1905,43 @@ struct GridwiseGemm_xdl_cshuffle_v3
});
});
}
}
}
}
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
const
Problem
&
problem
)
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
Run_2Lds
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_0
,
p_shared_1
,
problem
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
}
};
};
}
// namespace ck
}
// namespace ck
include/ck_tile/core.hpp
View file @
d39c3f5d
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/array.hpp"
...
@@ -47,10 +48,12 @@
...
@@ -47,10 +48,12 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
...
@@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i32"
);
// buffer store ui16
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_ui16
(
uint16_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_ui16x2
(
uint16x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_ui16x4
(
uint16x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i16"
);
CK_TILE_DEVICE_EXTERN
void
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i32x2
(
int32x2_t
vdata
,
llvm_amdgcn_raw_buffer_store_i32x2
(
int32x2_t
vdata
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
...
@@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
...
@@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
uint16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
// fp32
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
// fp32
...
@@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
...
@@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
}
}
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint16_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_ui16
(
bit_cast
<
uint16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_ui16x2
(
bit_cast
<
uint16x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_ui16x4
(
bit_cast
<
uint16x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_ui16x4
(
src_thread_data
.
template
get_as
<
uint16x4_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_ui16x4
(
src_thread_data
.
template
get_as
<
uint16x4_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
uint16_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
else
{
{
using
r_t
=
thread_buffer
<
int8_t
,
sizeof
(
T
)
*
N
>
;
using
r_t
=
thread_buffer
<
int8_t
,
sizeof
(
T
)
*
N
>
;
...
@@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
...
@@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
{
if
constexpr
(
N
==
2
)
if
constexpr
(
N
==
2
)
{
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
bit_cast
<
fp16_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
bit_cast
<
fp16
x2
_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
...
...
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
0 → 100644
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace
ck_tile
{
CK_TILE_HOST_DEVICE
bf16_t
add_bf16_t
(
const
bf16_t
&
a
,
const
bf16_t
&
b
)
{
return
type_convert
<
bf16_t
>
(
type_convert
<
float
>
(
a
)
+
type_convert
<
float
>
(
b
));
}
CK_TILE_HOST_DEVICE
bf16x2_t
add_bf16x2_t
(
const
bf16x2_t
&
a
,
const
bf16x2_t
&
b
)
{
bf16x2_t
rtn
;
rtn
[
0
]
=
add_bf16_t
(
a
[
0
],
b
[
0
]);
rtn
[
1
]
=
add_bf16_t
(
a
[
1
],
b
[
1
]);
return
rtn
;
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
// each datatype.
template
<
typename
X
>
CK_TILE_DEVICE
void
atomic_add
(
X
*
p_dst
,
const
X
&
x
);
template
<
>
CK_TILE_DEVICE
void
atomic_add
<
bf16x2_t
>
(
bf16x2_t
*
p_dst
,
const
bf16x2_t
&
x
)
{
union
U32BF162_ADDR
{
uint32_t
*
u32_a
;
bf16x2_t
*
bf162_a
;
};
union
U32BF162
{
uint32_t
u32
;
bf16x2_t
bf162
;
};
U32BF162_ADDR
dword_addr
;
U32BF162
cur_v
;
U32BF162
new_
;
uint32_t
old_v
,
new_v
;
dword_addr
.
bf162_a
=
p_dst
;
cur_v
.
u32
=
*
dword_addr
.
u32_a
;
do
{
old_v
=
cur_v
.
u32
;
new_
.
bf162
=
add_bf16x2_t
(
cur_v
.
bf162
,
x
);
new_v
=
new_
.
u32
;
cur_v
.
u32
=
atomicCAS
(
dword_addr
.
u32_a
,
old_v
,
new_v
);
}
while
(
cur_v
.
u32
!=
old_v
);
}
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
atomic_add_g
(
T
*
p_dst
,
const
thread_buffer
<
T
,
N
>&
x
)
{
static_assert
((
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
uint32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicAdd
(
p_dst
,
bit_cast
<
float
>
(
x
));
}
else
if
constexpr
(
N
==
2
)
{
atomicAdd
(
c_style_pointer_cast
<
float
*>
(
p_dst
),
x
.
template
get_as
<
float
>()[
I0
]);
atomicAdd
(
c_style_pointer_cast
<
float
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
float
>()[
I1
]);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
atomicAdd
(
p_dst
,
bit_cast
<
double
>
(
x
));
}
else
if
constexpr
(
N
==
2
)
{
atomicAdd
(
c_style_pointer_cast
<
double
*>
(
p_dst
),
x
.
template
get_as
<
double
>()[
I0
]);
atomicAdd
(
c_style_pointer_cast
<
double
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
double
>()[
I1
]);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicAdd
(
p_dst
,
bit_cast
<
int32_t
>
(
x
));
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicAdd
(
p_dst
,
bit_cast
<
uint32_t
>
(
x
));
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
bf16_t
>::
value
)
{
if
constexpr
(
N
==
2
)
{
atomic_add
(
c_style_pointer_cast
<
bf16x2_t
*>
(
p_dst
),
bit_cast
<
bf16x2_t
>
(
x
));
}
else
if
constexpr
(
N
==
4
)
{
atomic_add
(
c_style_pointer_cast
<
bf16x2_t
*>
(
p_dst
),
x
.
template
get_as
<
bf16x2_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
bf16x2_t
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
bf16x2_t
>()[
I1
]);
}
}
}
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
atomic_max_g
(
T
*
p_dst
,
const
thread_buffer
<
T
,
N
>&
x
)
{
static_assert
((
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
uint32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
)),
"wrong! not implemented"
);
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicMax
(
p_dst
,
bit_cast
<
float
>
(
x
));
}
else
if
constexpr
(
N
==
2
)
{
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
),
x
.
template
get_as
<
float
>()[
I0
]);
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
float
>()[
I1
]);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicMax
(
p_dst
,
bit_cast
<
double
>
(
x
));
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicMax
(
p_dst
,
bit_cast
<
int32_t
>
(
x
));
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
atomicMax
(
p_dst
,
bit_cast
<
uint32_t
>
(
x
));
}
}
}
}
// namespace ck_tile
include/ck_tile/core/numeric/vector_type.hpp
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
...
@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using
int8x32_t
=
int8_t
__attribute
((
ext_vector_type
(
32
)));
using
int8x32_t
=
int8_t
__attribute
((
ext_vector_type
(
32
)));
using
int8x64_t
=
int8_t
__attribute
((
ext_vector_type
(
64
)));
using
int8x64_t
=
int8_t
__attribute
((
ext_vector_type
(
64
)));
// ui8
// using uint8_t
using
uint8x2_t
=
uint8_t
__attribute
((
ext_vector_type
(
2
)));
using
uint8x4_t
=
uint8_t
__attribute
((
ext_vector_type
(
4
)));
using
uint8x8_t
=
uint8_t
__attribute
((
ext_vector_type
(
8
)));
using
uint8x16_t
=
uint8_t
__attribute
((
ext_vector_type
(
16
)));
using
uint8x32_t
=
uint8_t
__attribute
((
ext_vector_type
(
32
)));
using
uint8x64_t
=
uint8_t
__attribute
((
ext_vector_type
(
64
)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// f8
// using fp8_t
// using fp8_t
...
...
include/ck_tile/core/tensor/buffer_view.hpp
View file @
d39c3f5d
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
...
@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
...
@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
bool
constexpr
use_amd_buffer_addressing
=
false
;
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
#endif
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
if
constexpr
(
use_amd_buffer_addressing
)
if
constexpr
(
use_amd_buffer_addressing
)
{
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
}
}
...
@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
{
{
if
(
is_valid_element
)
if
(
is_valid_element
)
{
{
atomic_add
<
X
>
(
c_style_pointer_cast
<
X
*
>
(
&
p_data_
[
i
]
)
,
x
);
atomic_add
_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
],
x
);
}
}
}
}
}
}
...
@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
...
@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
bool
constexpr
use_amd_buffer_addressing
=
false
;
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
#endif
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
if
constexpr
(
use_amd_buffer_addressing
)
if
constexpr
(
use_amd_buffer_addressing
)
{
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
}
}
else
if
(
is_valid_element
)
else
if
(
is_valid_element
)
{
{
atomic_max
<
X
>
(
c_style_pointer_cast
<
X
*
>
(
&
p_data_
[
i
]
)
,
x
);
atomic_max
_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
],
x
);
}
}
}
}
...
...
include/ck_tile/core/tensor/store_tile.hpp
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/core/tensor/tensor_view.hpp
View file @
d39c3f5d
...
@@ -16,7 +16,9 @@
...
@@ -16,7 +16,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
BufferView_
,
typename
TensorDesc_
>
template
<
typename
BufferView_
,
typename
TensorDesc_
,
memory_operation_enum
DstInMemOp_
=
memory_operation_enum
::
set
>
struct
tensor_view
struct
tensor_view
{
{
using
buffer_view
=
remove_reference_t
<
BufferView_
>
;
using
buffer_view
=
remove_reference_t
<
BufferView_
>
;
...
@@ -24,6 +26,7 @@ struct tensor_view
...
@@ -24,6 +26,7 @@ struct tensor_view
using
TensorDesc
=
remove_cvref_t
<
TensorDesc_
>
;
using
TensorDesc
=
remove_cvref_t
<
TensorDesc_
>
;
using
TensorIndex
=
array
<
index_t
,
TensorDesc
::
get_num_of_top_dimension
()
>
;
using
TensorIndex
=
array
<
index_t
,
TensorDesc
::
get_num_of_top_dimension
()
>
;
using
TensorCoord
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
TensorIndex
{}));
using
TensorCoord
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
TensorIndex
{}));
static
constexpr
auto
DstInMemOp
=
DstInMemOp_
;
CK_TILE_HOST_DEVICE
constexpr
tensor_view
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
tensor_view
()
=
default
;
...
@@ -140,6 +143,23 @@ struct tensor_view
...
@@ -140,6 +143,23 @@ struct tensor_view
x
);
x
);
}
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements
(
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
}
CK_TILE_HOST_DEVICE
void
print
()
const
CK_TILE_HOST_DEVICE
void
print
()
const
{
{
printf
(
"tensor_view{"
);
printf
(
"tensor_view{"
);
...
@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
...
@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
}
}
template
<
address_space_enum
BufferAddressSpace
=
address_space_enum
::
generic
,
template
<
address_space_enum
BufferAddressSpace
=
address_space_enum
::
generic
,
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
,
typename
DataType
,
typename
DataType
,
typename
...
Lengths
,
typename
...
Lengths
,
typename
...
Strides
,
typename
...
Strides
,
...
@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
...
@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
auto
buffer_view
=
make_buffer_view
<
BufferAddressSpace
>
(
p
,
desc
.
get_element_space_size
());
auto
buffer_view
=
make_buffer_view
<
BufferAddressSpace
>
(
p
,
desc
.
get_element_space_size
());
return
tensor_view
<
decltype
(
buffer_view
),
decltype
(
desc
)
>
{
buffer_view
,
desc
};
return
tensor_view
<
decltype
(
buffer_view
),
decltype
(
desc
)
,
DstInMemOp
>
{
buffer_view
,
desc
};
}
}
template
<
address_space_enum
BufferAddressSpace
=
address_space_enum
::
generic
,
template
<
address_space_enum
BufferAddressSpace
=
address_space_enum
::
generic
,
...
@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
...
@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
NewLowerDimensionOldVisibleIdss
{},
NewLowerDimensionOldVisibleIdss
{},
NewUpperDimensionNewVisibleIdss
{});
NewUpperDimensionNewVisibleIdss
{});
return
tensor_view
<
typename
OldTensorView
::
buffer_view
,
remove_cvref_t
<
decltype
(
new_desc
)
>>
{
return
tensor_view
<
typename
OldTensorView
::
buffer_view
,
old_tensor_view
.
buf_
,
new_desc
};
remove_cvref_t
<
decltype
(
new_desc
)
>
,
remove_cvref_t
<
OldTensorView
>::
DstInMemOp
>
{
old_tensor_view
.
buf_
,
new_desc
};
}
}
template
<
typename
TensorView
,
template
<
typename
TensorView
,
...
...
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
d39c3f5d
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
d39c3f5d
...
@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
...
@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
});
});
}
}
template
<
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
Traits
=
load_store_traits
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
});
});
}
// move thread's botom tensor coordiante
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// [x0', x1', ... ] ==> [offset]
// also move window-origin
// also move window-origin
...
...
include/ck_tile/core/tensor/update_tile.hpp
0 → 100644
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
DataType_
>
CK_TILE_DEVICE
void
update_tile
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
{
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
DataType_
>
,
DataType
>
,
"wrong!"
);
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
tile_window
=
make_tile_window
(
tile_window_tmp
.
get_bottom_tensor_view
(),
tile_window_tmp
.
get_window_lengths
(),
tile_window_tmp
.
get_window_origin
(),
tile_dstr
);
tile_window
.
update
(
dstr_tensor
);
}
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
index_t
NumCoord
,
typename
DataType_
>
CK_TILE_DEVICE
void
update_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
{
tile_window
.
update
(
dstr_tensor
);
}
}
// namespace ck_tile
include/ck_tile/core/utility/philox_rand.hpp
0 → 100644
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace
ck_tile
{
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
class
philox
{
public:
CK_TILE_HOST_DEVICE
philox
(
unsigned
long
long
seed_
,
unsigned
long
long
offset_
)
:
seed
(
reinterpret_cast
<
const
uint2
&>
(
seed_
))
{
ull2
*
tmp
=
reinterpret_cast
<
ull2
*>
(
&
counter
);
tmp
->
x
=
offset_
;
}
CK_TILE_HOST_DEVICE
uint4
get_philox_4x32
(
const
unsigned
long
long
subsequence
)
const
{
uint4
counter_
=
counter
;
ull2
*
tmp
=
reinterpret_cast
<
ull2
*>
(
&
counter_
);
tmp
->
y
=
subsequence
;
uint2
key_
=
seed
;
// 7-round philox
#pragma unroll
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
counter_
=
philox_single_round
(
counter_
,
key_
);
key_
.
x
+=
kPhilox10A
;
key_
.
y
+=
kPhilox10B
;
}
uint4
output
=
philox_single_round
(
counter_
,
key_
);
return
output
;
}
CK_TILE_HOST_DEVICE
void
get_random_16x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp_ph
.
x
;
out_tmp
[
1
]
=
tmp_ph
.
y
;
out_tmp
[
2
]
=
tmp_ph
.
z
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
private:
struct
ull2
{
uint64_t
x
;
uint64_t
y
;
};
uint4
counter
;
const
uint2
seed
;
CK_TILE_HOST_DEVICE
uint2
mulhilo32
(
const
unsigned
int
a
,
const
unsigned
int
b
)
const
{
uint2
*
res
;
unsigned
long
long
tmp
;
tmp
=
static_cast
<
unsigned
long
long
>
(
a
)
*
b
;
res
=
reinterpret_cast
<
uint2
*>
(
&
tmp
);
return
*
res
;
}
CK_TILE_HOST_DEVICE
uint4
philox_single_round
(
const
uint4
ctr
,
const
uint2
key
)
const
{
uint2
res0
=
mulhilo32
(
kPhiloxSA
,
ctr
.
x
);
uint2
res1
=
mulhilo32
(
kPhiloxSB
,
ctr
.
z
);
uint4
ret
=
{
res1
.
y
^
ctr
.
y
^
key
.
x
,
res1
.
x
,
res0
.
y
^
ctr
.
w
^
key
.
y
,
res0
.
x
};
return
ret
;
}
static
const
unsigned
long
kPhilox10A
=
0x9E3779B9
;
static
const
unsigned
long
kPhilox10B
=
0xBB67AE85
;
static
const
unsigned
long
kPhiloxSA
=
0xD2511F53
;
static
const
unsigned
long
kPhiloxSB
=
0xCD9E8D57
;
};
}
// namespace ck_tile
include/ck_tile/host.hpp
View file @
d39c3f5d
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
...
...
include/ck_tile/host/host_tensor.hpp
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
...
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
}
}
const
std
::
vector
<
std
::
size_t
>&
get_lengths
()
const
{
return
mLens
;
}
const
std
::
vector
<
std
::
size_t
>&
get_lengths
()
const
{
return
mLens
;
}
const
std
::
vector
<
std
::
size_t
>&
G
et
S
trides
()
const
{
return
mStrides
;
}
const
std
::
vector
<
std
::
size_t
>&
g
et
_s
trides
()
const
{
return
mStrides
;
}
template
<
typename
...
Is
>
template
<
typename
...
Is
>
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
...
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
...
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
for
(
std
::
size_t
i
=
0
;
i
<
a
.
get_num_of_dimension
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
a
.
get_num_of_dimension
();
i
++
)
{
{
new_lengths
[
i
]
=
a
.
get_lengths
()[
new2old
[
i
]];
new_lengths
[
i
]
=
a
.
get_lengths
()[
new2old
[
i
]];
new_strides
[
i
]
=
a
.
G
et
S
trides
()[
new2old
[
i
]];
new_strides
[
i
]
=
a
.
g
et
_s
trides
()[
new2old
[
i
]];
}
}
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
...
@@ -327,7 +327,7 @@ struct HostTensor
...
@@ -327,7 +327,7 @@ struct HostTensor
decltype
(
auto
)
get_lengths
()
const
{
return
mDesc
.
get_lengths
();
}
decltype
(
auto
)
get_lengths
()
const
{
return
mDesc
.
get_lengths
();
}
decltype
(
auto
)
G
et
S
trides
()
const
{
return
mDesc
.
G
et
S
trides
();
}
decltype
(
auto
)
g
et
_s
trides
()
const
{
return
mDesc
.
g
et
_s
trides
();
}
std
::
size_t
get_num_of_dimension
()
const
{
return
mDesc
.
get_num_of_dimension
();
}
std
::
size_t
get_num_of_dimension
()
const
{
return
mDesc
.
get_num_of_dimension
();
}
...
@@ -481,6 +481,34 @@ struct HostTensor
...
@@ -481,6 +481,34 @@ struct HostTensor
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
HostTensor
<
T
>
transpose
(
std
::
vector
<
size_t
>
axes
=
{})
const
{
if
(
axes
.
empty
())
{
axes
.
resize
(
this
->
get_num_of_dimension
());
std
::
iota
(
axes
.
rbegin
(),
axes
.
rend
(),
0
);
}
if
(
axes
.
size
()
!=
mDesc
.
get_num_of_dimension
())
{
throw
std
::
runtime_error
(
"HostTensor::transpose(): size of axes must match tensor dimension"
);
}
std
::
vector
<
size_t
>
tlengths
,
tstrides
;
for
(
const
auto
&
axis
:
axes
)
{
tlengths
.
push_back
(
get_lengths
()[
axis
]);
tstrides
.
push_back
(
get_strides
()[
axis
]);
}
HostTensor
<
T
>
ret
(
*
this
);
ret
.
mDesc
=
HostTensorDescriptor
(
tlengths
,
tstrides
);
return
ret
;
}
HostTensor
<
T
>
transpose
(
std
::
vector
<
size_t
>
axes
=
{})
{
return
const_cast
<
HostTensor
<
T
>
const
*>
(
this
)
->
transpose
(
axes
);
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
typename
Data
::
iterator
end
()
{
return
mData
.
end
();
}
typename
Data
::
iterator
end
()
{
return
mData
.
end
();
}
...
...
include/ck_tile/host/reference/reference_batched_dropout.hpp
0 → 100644
View file @
d39c3f5d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace
ck_tile
{
template
<
typename
DataType
,
typename
RandValOutputDataType
>
CK_TILE_HOST
void
reference_batched_dropout
(
HostTensor
<
DataType
>&
in_out_b_m_n
,
const
HostTensor
<
RandValOutputDataType
>&
randval_b_m_n
,
const
uint8_t
&
p_undrop_in_uint8_t
,
const
float
scale
)
{
const
int
N
=
in_out_b_m_n
.
mDesc
.
get_lengths
()[
2
];
auto
f
=
[
&
](
auto
batch
,
auto
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
float
tmp
=
ck_tile
::
type_convert
<
float
>
(
in_out_b_m_n
(
batch
,
m
,
n
))
*
scale
;
in_out_b_m_n
(
batch
,
m
,
n
)
=
randval_b_m_n
(
batch
,
m
,
n
)
<=
p_undrop_in_uint8_t
?
ck_tile
::
type_convert
<
DataType
>
(
tmp
)
:
DataType
(
0
);
}
};
make_ParallelTensorFunctor
(
f
,
randval_b_m_n
.
mDesc
.
get_lengths
()[
0
],
randval_b_m_n
.
mDesc
.
get_lengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
}
// namespace ck_tile
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment