Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
df6dd915
Commit
df6dd915
authored
Apr 16, 2020
by
Jing Zhang
Browse files
formating
parent
e9f05865
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
991 additions
and
472 deletions
+991
-472
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
..._convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+0
-2
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+18
-8
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+19
-18
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+40
-39
composable_kernel/include/utility/amd_xdlops_emulate.hpp
composable_kernel/include/utility/amd_xdlops_emulate.hpp
+64
-33
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+0
-1
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+70
-71
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+217
-218
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
..._convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+74
-74
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+489
-8
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
df6dd915
...
...
@@ -158,7 +158,5 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
}
};
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
df6dd915
...
...
@@ -26,14 +26,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
index_t
col
;
};
//static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
// static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
// GemmDataPerReadA, GemmDataPerReadB>{};
index_t
mMyWaveOffsetA
;
index_t
mMyWaveOffsetB
;
static
constexpr
index_t
WaveSize
=
64
;
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
GetOutputLayout
();
}
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}
.
GetOutputLayout
();
}
__device__
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
()
{
...
...
@@ -67,7 +72,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
constexpr
index_t
N
=
BlockMatrixB
::
NCol
();
constexpr
index_t
K
=
BlockMatrixA
::
NRow
();
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
template
Run
<
M
,
N
,
K
>(
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}
.
template
Run
<
M
,
N
,
K
>(
&
p_a_block
[
mMyWaveOffsetA
],
&
p_b_block
[
mMyWaveOffsetB
],
p_c_thread
);
}
...
...
@@ -76,7 +82,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
auto
thread_mtx_on_blk
=
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
GetBeginOfThreadBlk
(
i
);
const
auto
thread_mtx_on_blk
=
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}
.
GetBeginOfThreadBlk
(
i
);
const
index_t
col
=
waveId
%
GemmNWaves
*
GemmNPerWave
+
thread_mtx_on_blk
.
col
;
...
...
@@ -94,14 +102,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
__device__
void
XdlopsMatrixCSetZero
()
const
{
constexpr
auto
thread_mtx_size
=
GemmMPerWave
*
GemmNPerWave
/
WaveSize
;
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
SetZeroXdlopsRegs
(
Number
<
thread_mtx_size
>
{});
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}
.
SetZeroXdlopsRegs
(
Number
<
thread_mtx_size
>
{});
}
template
<
class
FloatC
>
__device__
void
XdlopsMatrixCRead
(
FloatC
*
__restrict__
p_c_thread
)
const
{
constexpr
auto
thread_mtx_size
=
GemmMPerWave
*
GemmNPerWave
/
WaveSize
;
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
ReadXdlopsRegs
(
Number
<
thread_mtx_size
>
{},
p_c_thread
);
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}
.
ReadXdlopsRegs
(
Number
<
thread_mtx_size
>
{},
p_c_thread
);
}
};
...
...
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
df6dd915
...
...
@@ -117,7 +117,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
{
mThreadwiseStore
.
Run_optimized_dst_address_calculation
(
p_thread_buffer
,
p_block_dst
);
mThreadwiseStore
.
Run_optimized_dst_address_calculation
(
p_thread_buffer
,
p_block_dst
);
}
else
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
df6dd915
...
...
@@ -497,174 +497,171 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
}
};
template
<
class
data_type
,
index_t
MPerWave
,
index_t
NPerWave
>
__device__
constexpr
auto
GetMFMAInfo
();
template
<
class
data_type
,
index_t
MPerWave
,
index_t
NPerWave
>
__device__
constexpr
auto
GetMFMAInfo
();
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
32
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
64
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
64
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
32
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
16
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
16
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
64
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
8
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
float
,
4
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
64
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
64
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
32
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
32
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x8f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
16
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x16f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
16
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
64
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
4
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
half
,
8
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x4f16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
64
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
64
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
32
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
32
,
32
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
16
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x8bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
16
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
64
,
16
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_16x16x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
4
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x2bf16
>
{};
}
template
<
>
template
<
>
__device__
constexpr
auto
GetMFMAInfo
<
ushort
,
8
,
64
>
()
{
return
mfma_info
<
mfma_instr
::
mfma_f32_4x4x2bf16
>
{};
}
template
<
class
data_type
,
index_t
MPerWave
,
index_t
NPerWave
,
...
...
@@ -685,7 +682,10 @@ struct XdlopsGemm_t
__device__
static
constexpr
index_t
M0
()
{
return
M0_
;
}
__device__
static
constexpr
index_t
N1
()
{
return
N1_
;
}
__device__
static
constexpr
index_t
N0
()
{
return
N0_
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
GetMFMAInfo
<
data_type
,
MPerWave
,
NPerWave
>
().
num_regs_blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
GetMFMAInfo
<
data_type
,
MPerWave
,
NPerWave
>
().
num_regs_blk
;
}
__device__
static
constexpr
index_t
GetNumBlks
()
{
...
...
@@ -726,7 +726,6 @@ struct XdlopsGemm_t
return
mfma_type
.
num_output_blks
==
1
&&
mfma_type
.
num_input_blks
!=
1
;
}
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
template
<
index_t
M
,
index_t
N
,
index_t
K
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
...
...
@@ -843,7 +842,8 @@ struct XdlopsGemm_t
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
{
mfma_type
.
run
(
Number
<
MPerWave
>
{},
Number
<
NPerWave
>
{},
p_a_wave
,
p_b_wave
,
p_c_thread
);
mfma_type
.
run
(
Number
<
MPerWave
>
{},
Number
<
NPerWave
>
{},
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
}).
Else
([
&
](
auto
)
{
...
...
@@ -852,7 +852,8 @@ struct XdlopsGemm_t
for
(
index_t
k
=
0
;
k
<
K
;
k
+=
mfma_type
.
num_input_blks
)
{
mfma_type
.
run
(
Number
<
MPerWave
>
{},
Number
<
NPerWave
>
{},
p_a_wave
,
p_b_wave
,
p_c_thread
);
mfma_type
.
run
(
Number
<
MPerWave
>
{},
Number
<
NPerWave
>
{},
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
});
...
...
@@ -898,7 +899,7 @@ struct XdlopsGemm_t
__device__
void
SetZeroXdlopsRegs
(
Number
<
Size
>
)
const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
//gcnasm_accvgpr_zero<Size>();
//
gcnasm_accvgpr_zero<Size>();
#endif
}
...
...
@@ -907,8 +908,8 @@ struct XdlopsGemm_t
{
#if !CK_USE_AMD_XDLOPS_EMULATE
constexpr
auto
mfma_type
=
GetMFMAInfo
<
data_type
,
MPerWave
,
NPerWave
>
();
//gcnasm_nop<mfma_type.cycles>();
//gcnasm_accvgpr_read<Size>(p_c_thread);
//
gcnasm_nop<mfma_type.cycles>();
//
gcnasm_accvgpr_read<Size>(p_c_thread);
#else
(
void
)
p_c_thread
;
#endif
...
...
composable_kernel/include/utility/amd_xdlops_emulate.hpp
View file @
df6dd915
...
...
@@ -7,7 +7,8 @@ template <index_t MPerWave, index_t NPerWave>
__device__
void
gcnasm_mfma_f32_32x32x1f32
(
const
float
&
,
const
float
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
32
;
i
++
)
...
...
@@ -17,7 +18,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
32
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
32
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
16
;
i
++
)
...
...
@@ -27,7 +29,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
32
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
32
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
16
;
i
++
)
...
...
@@ -53,12 +56,14 @@ template <index_t MPerWave, index_t NPerWave>
__device__
void
gcnasm_mfma_f32_16x16x1f32
(
const
float
&
,
const
float
&
,
float16_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
16
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
16
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
64
,
16
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
64
,
16
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
{
}
...
...
@@ -66,66 +71,77 @@ template <index_t MPerWave, index_t NPerWave>
__device__
void
gcnasm_mfma_f32_4x4x1f32
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
4
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
4
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
8
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
8
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
(
const
half4_t
&
,
const
half4_t
&
,
float32_t
*
);
__device__
void
gcnasm_mfma_f32_32x32x4f16
(
const
half4_t
&
,
const
half4_t
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
32
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
32
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
32
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
32
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_32x32x8f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x8f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_16x16x16f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x16f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
);
__device__
void
gcnasm_mfma_f32_16x16x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
16
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
16
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
64
,
16
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
64
,
16
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
);
__device__
void
gcnasm_mfma_f32_4x4x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
4
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
4
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
8
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
8
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
...
...
@@ -133,54 +149,69 @@ template <index_t MPerWave, index_t NPerWave>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
(
const
ushort2_t
&
,
const
ushort2_t
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
32
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
32
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
32
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
32
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_32x32x4bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_32x32x4bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_16x16x8bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x8bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
);
__device__
void
gcnasm_mfma_f32_16x16x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
16
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
16
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
64
,
16
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
64
,
16
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
);
__device__
void
gcnasm_mfma_f32_4x4x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
4
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
4
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
8
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
8
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
// clang-format on
}
#endif
composable_kernel/include/utility/common_header.hpp
View file @
df6dd915
...
...
@@ -31,5 +31,4 @@
#include "amd_xdlops_emulate.hpp"
#endif
#endif
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
df6dd915
...
...
@@ -858,7 +858,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
cudaDeviceSynchronize
();
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
df6dd915
...
...
@@ -1048,7 +1048,6 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
cudaDeviceSynchronize
();
...
...
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
df6dd915
...
...
@@ -85,7 +85,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
<
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
...
...
@@ -161,7 +162,6 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
cudaDeviceSynchronize
();
...
...
driver/src/conv_driver.cpp
View file @
df6dd915
...
...
@@ -20,26 +20,495 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
// 1x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1536
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
160
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
96
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
71
;
constexpr
index_t
WI
=
71
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 7x1, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
320
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 0
// 1x7, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
224
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
224
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 1
// 3x3, 299x299 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
299
;
constexpr
index_t
WI
=
299
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
149
;
constexpr
index_t
WI
=
149
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 17x17, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
384
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x3, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
448
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
1
>
;
using
RightPads
=
Sequence
<
0
,
1
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 1
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 1
// 3x3, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
// 3x3, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
// 1x1, 56x56, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x7, 230x230 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
230
;
constexpr
index_t
WI
=
230
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride = 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
// 1x1, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 1
// 3x3, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#endif
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
...
...
@@ -133,6 +602,18 @@ int main(int argc, char* argv[])
LeftPads
{},
RightPads
{},
nrepeat
);
#elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#elif 1
device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment