gaoqiong / composable_kernel — Commit b134b7d6

Authored May 16, 2022 by carlushuang
Merge remote-tracking branch 'origin/develop' into cpu_avx2
Parents: 090ba885, 9f71ff48

Changes: 211 — showing 20 changed files with 575 additions and 359 deletions.

include/ck/tensor_operation/cpu/device/device_base_cpu.hpp (+2, -1)
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp (+7, -3)
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (+252, -9)
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp (+2, -2)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp (+19, -22)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp (+24, -27)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp (+29, -32)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp (+33, -36)
include/ck/tensor_operation/gpu/device/device_base.hpp (+9, -6)
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp (+82, -85)
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp (+16, -15)
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+26, -46)
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp (+10, -10)
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp (+9, -9)
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp (+10, -13)
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+9, -9)
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp (+11, -11)
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp (+7, -6)
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp (+8, -7)
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp (+10, -10)

include/ck/tensor_operation/cpu/device/device_base_cpu.hpp

@@ -2,6 +2,7 @@
 #define DEVICE_BASE_CPU_HPP

 #include <string>
+#include "stream_config.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -23,7 +24,7 @@ struct BaseInvoker
     BaseInvoker(const BaseInvoker&)            = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;

-    virtual float Run(const BaseArgument*, int = 1) = 0;
+    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}, int = 1) = 0;

     virtual ~BaseInvoker() {}
 };

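The practical effect of this change is that a CPU invoker's Run() now takes a StreamConfig ahead of the repeat count, both defaulted. A minimal call-site sketch; everything other than the Run() signature is illustrative, not part of this commit:

    // Hypothetical caller exercising the new BaseInvoker::Run() signature.
    float profile_once(BaseInvoker& invoker, const BaseArgument* p_arg)
    {
        // Previously: invoker.Run(p_arg, /*nrepeat=*/10);
        return invoker.Run(p_arg, StreamConfig{}, /*nrepeat=*/10);
    }
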
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp

@@ -690,7 +690,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 {
     using Argument = DeviceOp::Argument;

-    float Run(const Argument& arg, int nrepeat = 1)
+    float Run(const Argument& arg,
+              const StreamConfig& stream_config = StreamConfig{},
+              int nrepeat = 1)
     {
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
         {
@@ -743,9 +745,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         return ave_time;
     }

-    float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+    float Run(const BaseArgument* p_arg,
+              const StreamConfig& stream_config = StreamConfig{},
+              int nrepeat = 1) override
     {
-        return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
     }
 };

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

-#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
-#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
+#pragma once

 #include "common_header.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"
 #include "tensor_adaptor.hpp"
+#include "thread_group.hpp"

 namespace ck {

+enum struct LoopScheduler
+{
+    Default,
+    Interwave,
+};
+
+constexpr LoopScheduler make_default_loop_scheduler()
+{
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    return LoopScheduler::Interwave;
+#else
+    return LoopScheduler::Default;
+#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+}
+
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
@@ -25,7 +39,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    static constexpr index_t WaveSize = 64;
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    static constexpr index_t WaveSize = get_warp_size();

     static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
     static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
@@ -55,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     __device__ static auto GetWaveIdx()
     {
-        const index_t thread_id = get_thread_local_1d_id();
+        const index_t thread_id = ThisThreadBlock::GetThreadId();

         constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
@@ -122,8 +138,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(BlockSize == MWaves * NWaves * WaveSize,
-                      "BlockSize != MWaves * NWaves * WaveSize\n");
+        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

         static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                       "wrong!");
@@ -301,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         });
     }

-    private:
+    protected:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
@@ -338,5 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 };

+// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
+// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
+// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
+// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
+template <index_t BlockSize, typename FloatAB, typename FloatAcc,
+          typename AK0MK1BlockDesc, typename BK0NK1BlockDesc,
+          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack,
+          index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
+struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
+    : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+          BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+          MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
+{
+    using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+        BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+        MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;
+
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    using Base::a_block_desc_m0_m1_m2_k;
+    using Base::A_K1;
+    using Base::b_block_desc_n0_n1_n2_k;
+    using Base::B_K1;
+    using Base::c_thread_buf_;
+    using Base::c_thread_desc_;
+    using Base::CalculateAThreadOriginDataIndex;
+    using Base::CalculateBThreadOriginDataIndex;
+    using Base::I0;
+    using Base::I1;
+    using Base::KPerThread;
+    using Base::xdlops_gemm;
+
+    static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
+
+    // 2-wave optimized blockwise gemm
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+
+        static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, k),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, k),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_thread_buf);
+            });
+            __builtin_amdgcn_sched_barrier();
+            // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except
+            // the first, as we can shorten non-MAC cluster a bit and there's no observable negative
+            // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
+            // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the
+            // chance of latency hiding by waiting for the rest of the workgroup at the eventual
+            // sync point.
+            if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
+            {
+                asm volatile("s_barrier" ::);
+                __builtin_amdgcn_sched_barrier();
+            }
+            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<FloatAB, KPack> a_thread_vec;
+                        vector_type<FloatAB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto i) {
+                            a_thread_vec.template AsType<FloatAB>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, 0, 0, k_ + i))>{}];
+                            b_thread_vec.template AsType<FloatAB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(n0, 0, 0, k_ + i))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        // The block_sync_lds() here performs double duty:
+                        // A) safeguard against data hazard because barrier from blockwise_gemm is
+                        // moved here B) reduce VMEM FIFO congestion by applying small delays to
+                        // different wavefronts It is performed near the end of MAC cluster to
+                        // minimize lgkmcnt penalty
+                        if constexpr(k.value == KPerThread - KPerInnerLoop &&
+                                     k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
+                                     n0.value == NRepeat - 1)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            block_sync_lds();
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                        // TODO: insert setprio in more precise manner since we
+                        // could have more than >1 MFMA instructions in single call
+                        xdlops_gemm.template Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            __builtin_amdgcn_s_setprio(1);
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                    });
+                });
+            });
+            __builtin_amdgcn_sched_barrier();
+            __builtin_amdgcn_s_setprio(0);
+            __builtin_amdgcn_sched_barrier();
+        });
+    }
+
+    protected:
+    // A[M0, M1, M2, KPerInnerLoop]
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    // B[N0, N1, N2, KPerInnerLoop]
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<
+        FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_),
+        Sequence<1, 1, 1, KPerInnerLoop>, Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;
+
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<
+        FloatAB, FloatAB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_),
+        Sequence<1, 1, 1, KPerInnerLoop>, Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;
+
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+};
+
+template <index_t BlockSize, typename FloatAB, typename FloatAcc,
+          typename AK0MK1BlockDesc, typename BK0NK1BlockDesc,
+          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack,
+          LoopScheduler LoopSched>
+constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+            BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+            MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+            BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+            MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+};
+
 } // namespace ck
-#endif

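The scheduler choice is a compile-time value, so callers pick the pipeline with a template argument. A small sketch against the additions above; the descriptor and tile parameters in the commented instantiation are placeholders from an enclosing gridwise GEMM, not defined here:

    // make_default_loop_scheduler() resolves at compile time from
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING.
    constexpr auto default_sched = ck::make_default_loop_scheduler();
    static_assert(default_sched == ck::LoopScheduler::Default ||
                  default_sched == ck::LoopScheduler::Interwave);

    // Inside a gridwise GEMM one would then select the blockwise pipeline, e.g.:
    //   auto blockwise_gemm =
    //       BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
    //           BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
    //           MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, LoopSched>();
    // With LoopScheduler::Interwave this maps to the 2-wave optimized Run() when
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING is enabled.
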
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp

@@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1
               src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&

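The only change in this header swaps the nested remove_reference_t<remove_cv_t<...>> spelling for the remove_cvref_t shorthand. A minimal sketch of what such an alias means, assuming a standalone definition for illustration (CK provides its own):

    #include <type_traits>

    // Hypothetical stand-in for the alias used in the assertions above.
    template <typename T>
    using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

    static_assert(std::is_same_v<remove_cvref_t<const int&>, int>, "strips cv and ref");
    static_assert(std::is_same_v<remove_cvref_t<volatile float&&>, float>, "strips cv and ref");
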
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp → include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,7 +11,7 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename SrcElementwiseOperation,
           typename DstElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
@@ -35,7 +33,7 @@ template <index_t BlockSize,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun,
           index_t NumThreadScratch = 1>
-struct BlockwiseTensorSliceTransfer_v4r1
+struct ThreadGroupTensorSliceTransfer_v4r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
         const SrcDesc& src_desc,
         const Index& src_block_slice_origin,
         const SrcElementwiseOperation& src_element_op,
@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                               dst_element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
             is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");

-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");

-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                         const SrcBuffer& src_buf,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
         }
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                         DstBuffer& dst_buf,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
         }
@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
         }
@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
 };

 } // namespace ck
-#endif

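The renamed class is parameterized on a ThreadGroup type rather than a bare BlockSize, so any type exposing static GetNumOfThread() and GetThreadId() can drive the transfer; ThisThreadBlock<BlockSize> from thread_group.hpp, used by the blockwise GEMM above, is the in-tree choice. A hypothetical stub showing only the required shape of such a type:

    // Illustrative thread group satisfying the interface the transfer now expects.
    template <ck::index_t NThread>
    struct MyThreadGroup
    {
        __device__ static constexpr ck::index_t GetNumOfThread() { return NThread; }
        __device__ static ck::index_t GetThreadId() { return ck::get_thread_local_1d_id(); }
    };

    // Old usage: BlockwiseTensorSliceTransfer_v4r1<BlockSize, ...>
    // New usage: ThreadGroupTensorSliceTransfer_v4r1<MyThreadGroup<256>, ...>
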
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp → include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename SrcData,
@@ -28,19 +26,19 @@ template <index_t BlockSize,
           index_t ScalarPerVector,
           bool ThreadTransferSrcResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r1
+struct ThreadGroupTensorSliceTransfer_v6r1
 {
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                            const Index& src_block_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_block_slice_origin,
                                                            const ElementwiseOperation& element_op)
         : threadwise_transfer_(src_desc,
                                make_zero_multi_index<nDim>(),
                                dst_desc,
@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");

         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");

-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");

-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
         }
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
         }
@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1
 };

 } // namespace ck
-#endif

include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp → include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. It does not keep reference to tensor descriptor
 // 3. Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename Src0Data,
@@ -31,21 +29,21 @@ template <index_t BlockSize,
           bool ThreadTransferSrc0ResetCoordinateAfterRun,
           bool ThreadTransferSrc1ResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r2
+struct ThreadGroupTensorSliceTransfer_v6r2
 {
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
                                                            const Index& src0_block_slice_origin,
                                                            const Src1Desc& src1_desc,
                                                            const Index& src1_block_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_block_slice_origin,
                                                            const ElementwiseOperation& element_op)
         : threadwise_transfer_(src0_desc,
                                make_zero_multi_index<nDim>(),
                                src1_desc,
@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");

         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");

-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");

-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
         }
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
         }
@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
         }
@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
 };

 } // namespace ck
-#endif

include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp → include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp

-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -13,10 +11,10 @@ namespace ck {
 // 1. Use StaticallyIndexedArray instead of C array for thread buffer
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
+template <typename ThreadGroup,
           typename ElementwiseOperation,
           InMemoryDataOperationEnum DstInMemOp,
-          typename BlockSliceLengths,
+          typename SliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
           typename Src0Data,
@@ -34,23 +32,23 @@ template <index_t BlockSize,
           bool ThreadTransferSrc1ResetCoordinateAfterRun,
           bool ThreadTransferSrc2ResetCoordinateAfterRun,
           bool ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseTensorSliceTransfer_v6r3
+struct ThreadGroupTensorSliceTransfer_v6r3
 {
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

-    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
+    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
                                                            const Index& src0_block_slice_origin,
                                                            const Src1Desc& src1_desc,
                                                            const Index& src1_block_slice_origin,
                                                            const Src2Desc& src2_desc,
                                                            const Index& src2_block_slice_origin,
                                                            const DstDesc& dst_desc,
                                                            const Index& dst_block_slice_origin,
                                                            const ElementwiseOperation& element_op)
         : threadwise_transfer_(src0_desc,
                                make_zero_multi_index<nDim>(),
                                src1_desc,
@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3
                                element_op)
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size() &&
                           nDim == ThreadClusterArrangeOrder::Size() &&
                           nDim == DimAccessOrder::Size(),
                       "wrong! nDim not consistent");

         static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
             "wrong! threads should be mapped to cover entire slicing window");

-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
+        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
+                      "wrong! ThreadGroup::GetNumOfThread() too small");

-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(get_thread_local_1d_id()));
@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.Run(
                 src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
         }
@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
         }
@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
         }
@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
         }
@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
 };

 } // namespace ck
-#endif

include/ck/tensor_operation/gpu/device/device_base.hpp

-#ifndef DEVICE_BASE_HPP
-#define DEVICE_BASE_HPP
+#pragma once

 #include <string>
+#include "stream_config.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -22,7 +23,10 @@ struct BaseInvoker
     BaseInvoker(const BaseInvoker&)            = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;

-    virtual float Run(const BaseArgument*, int = 1) = 0;
+    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
+    {
+        return float{0};
+    }

     virtual ~BaseInvoker() {}
 };
@@ -33,8 +37,8 @@ struct BaseOperator
     BaseOperator(const BaseOperator&)            = default;
     BaseOperator& operator=(const BaseOperator&) = default;

-    virtual bool IsSupportedArgument(const BaseArgument*) = 0;
-    virtual std::string GetTypeString() const             = 0;
+    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
+    virtual std::string GetTypeString() const { return ""; }

     virtual ~BaseOperator() {}
 };
@@ -42,4 +46,3 @@ struct BaseOperator
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif

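On the GPU side, BaseInvoker::Run(), IsSupportedArgument() and GetTypeString() are no longer pure virtual; they now have permissive defaults (0.0f, false, ""). A sketch of a derived operator under the relaxed base class, purely illustrative:

    #include <string>

    // Hypothetical operator; only the base-class defaults come from this commit.
    struct MyDeviceOp final : public ck::tensor_operation::device::BaseOperator
    {
        bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument* p_arg) override
        {
            return p_arg != nullptr; // overriding remains the normal practice
        }

        std::string GetTypeString() const override { return "MyDeviceOp"; }
    };
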
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
b134b7d6
...
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
...
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ElementwiseOperation
,
typename
D1ReduceOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -44,8 +43,7 @@ __global__ void
...
@@ -44,8 +43,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
D0ReduceOperation
d0_reduce_op
,
const
D1ElementwiseOperation
d1_element_op
,
const
D1ReduceOperation
d1_reduce_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -82,8 +80,7 @@ __global__ void
...
@@ -82,8 +80,7 @@ __global__ void
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
d0_reduce_op
,
d1_element_op
,
d1_reduce_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -99,8 +96,7 @@ __global__ void
...
@@ -99,8 +96,7 @@ __global__ void
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
d0_reduce_op
;
ignore
=
d1_element_op
;
ignore
=
d1_reduce_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
@@ -110,6 +106,9 @@ __global__ void
...
@@ -110,6 +106,9 @@ __global__ void
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
}
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
@@ -125,6 +124,7 @@ template <typename ALayout,
           typename CElementwiseOperation,
           typename D0ReduceOperation,
           typename D1ReduceOperation,
+          typename D1ElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -157,12 +157,12 @@ template <typename ALayout,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
+          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                                       BElementwiseOperation,
                                                                       CElementwiseOperation,
-                                                                      D0ReduceOperation,
-                                                                      D1ReduceOperation>
+                                                                      D1ElementwiseOperation>
 {
     using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
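The two hunks above thread a new LoopScheduler template parameter, defaulting to make_default_loop_scheduler(), from the device op down into the gridwise GEMM, matching the note that inter-wave scheduling is only wired up for the c-shuffle kernels for now. As a rough, non-authoritative sketch of what such a selector can look like (the enum values and the macro name below are assumptions, not the exact composable_kernel definitions):

    // Hypothetical sketch of a loop-scheduler selector; all names are assumptions.
    enum struct LoopScheduler
    {
        Default,   // classic single-wave software pipelining
        Interwave  // interleave math and memory phases across waves
    };

    constexpr LoopScheduler make_default_loop_scheduler()
    {
    #if defined(CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING)
        return LoopScheduler::Interwave;
    #else
        return LoopScheduler::Default;
    #endif
    }

Because the new parameter carries a default argument, existing instantiations keep compiling unchanged while individual kernels can opt into the inter-wave variant.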
@@ -564,6 +564,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         CElementwiseOperation,
         D0ReduceOperation,
         D1ReduceOperation,
+        D1ElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_AK0_M_AK1,
@@ -603,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         CReduceThreadClusterLengths_MPerBlock_NPerBlock,
         CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
+        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+        LoopSched>;

     using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
@@ -624,8 +626,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 D0ReduceOperation d0_reduce_op,
-                 D1ReduceOperation d1_reduce_op,
+                 D1ElementwiseOperation d1_element_op,
                  index_t BatchCount)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
@@ -639,17 +640,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
               d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               d_grid_desc_mblock_mperblock_{},
-              compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
-                                         b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize()},
+              compute_base_ptr_of_batch_{
+                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op},
-              d0_reduce_op_{d0_reduce_op},
-              d1_reduce_op_{d1_reduce_op}
+              d1_element_op_{d1_element_op}
         {
             if(GridwiseGemm::CheckValidity(
                    a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -684,8 +685,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        D0ReduceOperation d0_reduce_op_;
-        D1ReduceOperation d1_reduce_op_;
+        D1ElementwiseOperation d1_element_op_;
     };

     // Invoker
@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int /* nrepeat */ = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
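The Run overloads above drop the old int nrepeat parameter in favour of a const StreamConfig&, so the choice of HIP stream and whether to time the launch travel together in a single argument and the invoker returns the measured time. A small standalone sketch of that pattern, assuming a config struct with a stream handle and a timing flag (the names below are illustrative, not the exact CK types):

    // Standalone illustration of a StreamConfig-style launcher; not the CK code.
    #include <chrono>
    #include <functional>

    struct StreamConfigSketch
    {
        void* stream      = nullptr; // stand-in for a hipStream_t (assumption)
        bool  time_kernel = false;   // when true, report elapsed milliseconds
    };

    inline float launch_and_time_sketch(const StreamConfigSketch& cfg,
                                        const std::function<void()>& launch)
    {
        if(!cfg.time_kernel)
        {
            launch();    // fire-and-forget on cfg.stream
            return 0.0f; // callers that do not time get 0, like the old path
        }

        const auto t0 = std::chrono::steady_clock::now();
        launch();
        const auto t1 = std::chrono::steady_clock::now();
        return std::chrono::duration<float, std::milli>(t1 - t0).count();
    }

The benefit over the nrepeat interface is that benchmarking policy lives in one place instead of being re-implemented by every device operation.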
@@ -726,11 +726,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            float elapsed_time = 0.0f;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
                     GridwiseGemm,
@@ -740,8 +740,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    D0ReduceOperation,
-                    D1ReduceOperation,
+                    D1ElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -750,27 +749,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                     remove_reference_t<Block2CTileMap>,
                     true>;

-                launch_kernel(kernel,
-                              dim3(grid_size),
-                              dim3(BlockSize),
-                              0,
-                              arg.p_a_grid_,
-                              arg.p_b_grid_,
-                              arg.p_c_grid_,
-                              arg.p_d0_grid_,
-                              arg.p_d1_grid_,
-                              arg.BatchCount_,
-                              arg.a_element_op_,
-                              arg.b_element_op_,
-                              arg.c_element_op_,
-                              arg.d0_reduce_op_,
-                              arg.d1_reduce_op_,
-                              arg.a_grid_desc_ak0_m_ak1_,
-                              arg.b_grid_desc_bk0_n_bk1_,
-                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                              arg.d_grid_desc_mblock_mperblock_,
-                              arg.compute_base_ptr_of_batch_,
-                              arg.block_2_ctile_map_);
+                elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.p_d0_grid_,
+                                           arg.p_d1_grid_,
+                                           arg.BatchCount_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.d1_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.d_grid_desc_mblock_mperblock_,
+                                           arg.compute_base_ptr_of_batch_,
+                                           arg.block_2_ctile_map_);
             }
             else
             {
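The @@ -726 hunk above replaces the K0-based check with CalculateHasMainKBlockLoop(K), where K is rebuilt from the K0 x M x K1 descriptor as GetLength(I0) * GetLength(I2), i.e. the full reduction length K = K0 * K1. A hedged sketch of what such a predicate typically boils down to (KPerBlock here is an assumed parameter; in the real code it comes from the GridwiseGemm template arguments):

    // Illustrative only: does the GEMM need a main K loop, or just the tail tile?
    constexpr bool calculate_has_main_k_block_loop(int K, int KPerBlock)
    {
        const int num_k_tiles = K / KPerBlock; // K tiles processed per workgroup
        return num_k_tiles > 1;                // a single tile is handled by the tail path
    }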
@@ -782,8 +782,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    D0ReduceOperation,
-                    D1ReduceOperation,
+                    D1ElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -792,36 +791,38 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                     remove_reference_t<Block2CTileMap>,
                     false>;

-                launch_kernel(kernel,
-                              dim3(grid_size),
-                              dim3(BlockSize),
-                              0,
-                              arg.p_a_grid_,
-                              arg.p_b_grid_,
-                              arg.p_c_grid_,
-                              arg.p_d0_grid_,
-                              arg.p_d1_grid_,
-                              arg.BatchCount_,
-                              arg.a_element_op_,
-                              arg.b_element_op_,
-                              arg.c_element_op_,
-                              arg.d0_reduce_op_,
-                              arg.d1_reduce_op_,
-                              arg.a_grid_desc_ak0_m_ak1_,
-                              arg.b_grid_desc_bk0_n_bk1_,
-                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                              arg.d_grid_desc_mblock_mperblock_,
-                              arg.compute_base_ptr_of_batch_,
-                              arg.block_2_ctile_map_);
+                elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.p_d0_grid_,
+                                           arg.p_d1_grid_,
+                                           arg.BatchCount_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.d1_element_op_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.d_grid_desc_mblock_mperblock_,
+                                           arg.compute_base_ptr_of_batch_,
+                                           arg.block_2_ctile_map_);
             }

-            return 0;
+            return elapsed_time;
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -865,8 +866,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,
-                             D0ReduceOperation d0_reduce_op,
-                             D1ReduceOperation d1_reduce_op,
+                             D1ElementwiseOperation d1_element_op,
                              index_t BatchCount)
     {
         return Argument{p_a,
@@ -883,8 +883,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        d0_reduce_op,
-                        d1_reduce_op,
+                        d1_element_op,
                         BatchCount};
     }
@@ -905,8 +904,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op,
-                        D0ReduceOperation d0_reduce_op,
-                        D1ReduceOperation d1_reduce_op,
+                        D1ElementwiseOperation d1_element_op,
                         index_t BatchCount) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -923,8 +921,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          d0_reduce_op,
-                                          d1_reduce_op,
+                                          d1_element_op,
                                           BatchCount);
     }
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
@@ -107,7 +107,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-    ignore = compute_base_ptr_of_batch_;
+    ignore = compute_ptr_offset_of_batch;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
               DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                           b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                           c_grid_desc_m_n_.GetElementSpaceSize()},
+              compute_ptr_offset_of_batch_{
+                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
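In the constructor above, the per-batch strides handed to compute_ptr_offset_of_batch_ are now wrapped in type_convert<index_t>(...): GetElementSpaceSize() comes back from the descriptor as a wider integer type, while the batch-offset helper stores plain index_t strides, so the conversion makes the narrowing explicit rather than implicit. A hedged sketch of the kind of helper these strides feed (member and method names are assumptions, not the exact CK definitions):

    // Illustrative batch-offset helper: the g-th batch starts g * stride elements
    // into each tensor. Names and types are assumptions for illustration only.
    struct ComputePtrOffsetOfBatchSketch
    {
        int batch_stride_a; // element-space size of one A batch
        int batch_stride_b; // element-space size of one B batch
        int batch_stride_c; // element-space size of one C batch

        long GetAPtrOffset(int g) const { return static_cast<long>(g) * batch_stride_a; }
        long GetBPtrOffset(int g) const { return static_cast<long>(g) * batch_stride_b; }
        long GetCPtrOffset(int g) const { return static_cast<long>(g) * batch_stride_c; }
    };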
@@ -427,7 +428,7 @@ struct DeviceBatchedGemmXdl
     {
         using Argument = DeviceBatchedGemmXdl::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -455,13 +456,12 @@ struct DeviceBatchedGemmXdl
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_batched_gemm_xdlops_v2r3<
                     GridwiseGemm,
@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
                     remove_reference_t<Block2CTileMap>,
                     true>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
                     remove_reference_t<Block2CTileMap>,
                     false>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
         }

         // polymorphic
-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             ShowInfo(arg);

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                             arg.b_grid_desc_kbatch_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,
@@ -437,49 +438,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
-                if(nrepeat > 0)
-                {
-                    ave_time =
-                        launch_and_time_kernel(kernel,
-                                               nrepeat,
-                                               dim3(grid_size),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.p_a_grid_,
-                                               arg.p_b_grid_,
-                                               arg.p_c_grid_,
-                                               arg.a_grid_desc_kbatch_k0_m_k1_,
-                                               arg.b_grid_desc_kbatch_k0_n_k1_,
-                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                               arg.a_element_op_,
-                                               arg.b_element_op_,
-                                               arg.c_element_op_,
-                                               arg.block_2_ctile_map_);
-                }
-
-                if(kbatch > 1 || nrepeat <= 0)
-                {
-                    hipGetErrorString(hipMemset(
-                        arg.p_c_grid_,
-                        0,
-                        arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-                            sizeof(CDataType)));
-
-                    launch_kernel(kernel,
-                                  dim3(grid_size),
-                                  dim3(BlockSize),
-                                  0,
-                                  arg.p_a_grid_,
-                                  arg.p_b_grid_,
-                                  arg.p_c_grid_,
-                                  arg.a_grid_desc_kbatch_k0_m_k1_,
-                                  arg.b_grid_desc_kbatch_k0_n_k1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.a_element_op_,
-                                  arg.b_element_op_,
-                                  arg.c_element_op_,
-                                  arg.block_2_ctile_map_);
-                }
+                hipGetErrorString(hipMemset(
+                    arg.p_c_grid_,
+                    0,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+                        sizeof(CDataType)));
+
+                launch_and_time_kernel(stream_config,
+                                       kernel,
+                                       dim3(grid_size),
+                                       dim3(BlockSize),
+                                       0,
+                                       arg.p_a_grid_,
+                                       arg.p_b_grid_,
+                                       arg.p_c_grid_,
+                                       arg.a_grid_desc_kbatch_k0_m_k1_,
+                                       arg.b_grid_desc_kbatch_k0_n_k1_,
+                                       arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                       arg.a_element_op_,
+                                       arg.b_element_op_,
+                                       arg.c_element_op_,
+                                       arg.block_2_ctile_map_);
             };

             if(has_main_k0_block_loop)
@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
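The backward-weight hunks above fold the two launch paths into one: the output buffer is always cleared with hipMemset before the kernel runs, because the split-K (kbatch > 1) variant writes its partial results with atomic adds and therefore accumulates into whatever is already in C. A minimal sketch of that requirement, with the kernel launch itself elided (the function and buffer names below are placeholders, not CK code):

    // Why the memset matters: an atomic-add epilogue computes C += partial_k for
    // every K split, so the accumulation buffer must start at zero.
    #include <cstddef>
    #include <hip/hip_runtime.h>

    hipError_t zero_then_accumulate(float* p_c, std::size_t c_elements)
    {
        // Clear the result once per launch, not once per timing repeat.
        const hipError_t err = hipMemset(p_c, 0, c_elements * sizeof(float));
        if(err != hipSuccess)
            return err;

        // ... launch the split-K kernel(s) here; each split does atomicAdd into p_c ...
        return hipSuccess;
    }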
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             float ave_time = 0;
             for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                 const index_t grid_size =
                     GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

-                const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
-
-                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+                const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
+                               arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);

-                if(has_main_k0_block_loop)
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
                 {
                     const auto kernel = kernel_gemm_xdlops_v2r3<
                         GridwiseGemm,
@@ -603,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                         true>;

                     ave_time += launch_and_time_kernel(
+                        stream_config,
                         kernel,
-                        nrepeat,
                         dim3(grid_size),
                         dim3(BlockSize),
                         0,
@@ -636,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                         false>;

                     ave_time += launch_and_time_kernel(
+                        stream_config,
                         kernel,
-                        nrepeat,
                         dim3(grid_size),
                         dim3(BlockSize),
                         0,
@@ -656,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -698,7 +698,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         }

         // Gridwise GEMM size
-        for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
View file @
b134b7d6
...
@@ -642,7 +642,7 @@ struct
...
@@ -642,7 +642,7 @@ struct
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{}
)
{
{
#if 0
#if 0
{
{
...
@@ -698,13 +698,12 @@ struct
...
@@ -698,13 +698,12 @@ struct
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r3
<
const
auto
kernel
=
kernel_gemm_xdlops_v3r3
<
GridwiseGemm
,
GridwiseGemm
,
...
@@ -728,8 +727,8 @@ struct
...
@@ -728,8 +727,8 @@ struct
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -772,8 +771,8 @@ struct
...
@@ -772,8 +771,8 @@ struct
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -796,9 +795,10 @@ struct
...
@@ -796,9 +795,10 @@ struct
return
ave_time
;
return
ave_time
;
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
nrepeat
);
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
-#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
-#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
+#pragma once
 #include <iostream>
 #include <sstream>
 #include "device.hpp"
@@ -607,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r2<
                     GridwiseGemm,
@@ -687,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
                     true>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -726,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
                     false>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -748,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -919,4 +917,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r1<
                     GridwiseGemm,
@@ -664,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                     true>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -698,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                     false>;

                 ave_time = launch_and_time_kernel(
+                    stream_config,
                     kernel,
-                    nrepeat,
                     dim3(grid_size),
                     dim3(BlockSize),
                     0,
@@ -718,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     {
         using Argument = DeviceOp::Argument;

-        float Run(const Argument& arg, int nrepeat = 1)
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
 #if 0
             {
@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<
                     GridwiseGemm,
@@ -499,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     true>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -530,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                     remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                     false>;

-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   0,
@@ -550,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             return ave_time;
         }

-        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
         {
-            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
This diff is collapsed.
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
This diff is collapsed.
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
This diff is collapsed.