gaoqiong / composable_kernel / Commits / 583755a7

Commit 583755a7, authored Jul 29, 2019 by Tejash Shah

    Add fp16 support in implicit gemm

parent c15ff3c8
Showing 17 changed files with 924 additions and 366 deletions (+924, -366)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp   +10 -23
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp   +4 -4
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp   +13 -15
composable_kernel/include/tensor_operation/blockwise_gemm.hpp   +80 -43
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp   +57 -60
composable_kernel/include/tensor_operation/threadwise_gemm.hpp   +76 -11
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp   +25 -8
composable_kernel/include/utility/Sequence.hpp   +17 -18
composable_kernel/include/utility/amd_inline_asm.hpp   +52 -0
composable_kernel/include/utility/bfloat16_dev.hpp   +125 -0
composable_kernel/include/utility/common_header.hpp   +6 -0
composable_kernel/include/utility/float_types.h   +111 -0
composable_kernel/include/utility/integral_constant.hpp   +14 -48
composable_kernel/include/utility/math.hpp   +10 -6
composable_kernel/include/utility/vector_type.hpp   +119 -8
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp   +181 -110
driver/src/driver.cpp   +24 -12
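The common thread through the file-by-file changes below is a mixed-precision GEMM path: A/B operands are kept in fp16 (or bfloat16) while products are accumulated in fp32. As a rough, self-contained illustration of that accumulation rule (plain C++ with illustrative names, not the repo's API; the actual code goes through the CVT_FLOAT2ACCUM/CVT_ACCUM2FLOAT macros added in float_types.h, and _Float16 here needs a recent clang/gcc):

    #include <cstdio>

    using half_t = _Float16; // stand-in for the repo's "half"

    // fp16 inputs, fp32 accumulator: acc += (float)a[k] * (float)b[k]
    float dot_accum_fp32(const half_t* a, const half_t* b, int K)
    {
        float acc = 0.0f; // keep the running sum in fp32 to limit rounding error
        for(int k = 0; k < K; ++k)
            acc += static_cast<float>(a[k]) * static_cast<float>(b[k]);
        return acc;
    }

    int main()
    {
        half_t a[4] = {half_t(1), half_t(2), half_t(3), half_t(4)};
        half_t b[4] = {half_t(0.5), half_t(0.5), half_t(0.5), half_t(0.5)};
        std::printf("%f\n", dot_accum_fp32(a, b, 4)); // prints 5.000000
    }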
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -15,6 +15,7 @@ namespace ck {
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
+          class AccDataType,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,

@@ -50,9 +51,10 @@ template <index_t GridSize,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
 {
-    __device__ void Run(const Float* const __restrict__ p_in_global,
-                        const Float* const __restrict__ p_wei_global,
-                        Float* const __restrict__ p_out_global) const
+    __device__ void
+    Run(const Float* const __restrict__ p_in_global,
+        const Float* const __restrict__ p_wei_global,
+        Float* const __restrict__ p_out_global) const
     {
         // this is a mess
         // TODO: find more elegent way of specifying (or calculating) performance parameters

@@ -84,12 +86,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

-        constexpr index_t ConvStrideH = ConvStrides{}[0];
-        constexpr index_t ConvStrideW = ConvStrides{}[1];
-
-        constexpr index_t ConvDilationH = ConvDilations{}[0];
-        constexpr index_t ConvDilationW = ConvDilations{}[1];
-
         static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

         constexpr index_t N0 = N / (N1 * N2);

@@ -98,14 +94,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t E = C * Y * X;

-        // sanity-check for vectorized memory load
-        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
-                      "wrong! global vector load of input tensor is wrong");
-
-        static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
-                      "wrong! aligment requirement for vectorized global load of input tensor will "
-                      "be violated");
-
         // divide block work by [K, B]
         static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                       "wrong! cannot divide work evenly among block");

@@ -125,15 +113,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         // input tensor
         //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
         constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
-                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{})
                 .Extract(Sequence<0, 1, 2, 4, 5>{});

         //   batch descritpor for device memory
         constexpr auto in_c_y_x_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
-                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
                 .Extract(Sequence<1, 2, 3>{});

         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy

@@ -260,7 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         __shared__ Float p_wei_block_double[2 * wei_block_space];

         // register allocation for output
-        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
+        AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

         // zero out threadwise output
         threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

@@ -332,7 +320,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
                 blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                             p_wei_register_clipboard);

-                // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

                 // LDS double buffer: store next data to LDS
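A note on the hunks above: the new AccDataType template parameter decouples the type of the thread-private output tile from the Float type used in global memory, which is what lets the fp16 kernel keep an fp32 accumulator. A minimal sketch of that idea (simplified names, device qualifiers omitted, not the actual kernel signature):

    // Sketch only: Float is the storage type (e.g. half), AccDataType the
    // accumulation type (e.g. float).
    template <class Float, class AccDataType, int ThreadTileSize>
    struct GridwiseGemmSketch
    {
        void Run(const Float* /*p_in*/, const Float* /*p_wei*/, Float* /*p_out*/) const
        {
            AccDataType p_out_thread[ThreadTileSize]; // accumulate wide, store narrow
            for(int i = 0; i < ThreadTileSize; ++i)
                p_out_thread[i] = AccDataType(0);
            // ... GEMM accumulates into p_out_thread; values are converted back
            // to Float only on the final write-out ...
            (void)p_out_thread;
        }
    };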
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp

@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
         return OriginalTensorDesc{};
     }

-    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
+    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

     template <index_t IDim>
     __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)

@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
+    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
     {
         constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);

@@ -60,7 +60,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
+    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
     {
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");

@@ -75,7 +75,7 @@ struct ConstantMergedTensorDescriptor
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
     }

-    __host__ __device__ static constexpr auto GetElementSize()
+    __host__ __device__ static constexpr index_t GetElementSize()
     {
         return OriginalTensorDesc::GetElementSize();
     }
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp

@@ -43,22 +43,22 @@ struct ConstantTensorDescriptor
         return Sequence<IDim>{};
     }

-    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
+    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

     __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

-    template <class IDim>
-    __host__ __device__ static constexpr auto GetLength(IDim)
+    template <index_t I>
+    __host__ __device__ static constexpr index_t GetLength(Number<I>)
     {
-        return Lengths::Get(IDim{});
+        return Lengths::Get(Number<I>{});
     }

-    template <class IDim>
-    __host__ __device__ static constexpr auto GetStride(IDim)
+    template <index_t I>
+    __host__ __device__ static constexpr index_t GetStride(Number<I>)
     {
-        return Strides::Get(IDim{});
+        return Strides::Get(Number<I>{});
     }

     struct lambda_AreDimensionsContinuous

@@ -102,18 +102,17 @@ struct ConstantTensorDescriptor
         return false;
     }

-    __host__ __device__ static constexpr auto GetElementSize()
+    __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return Number<accumulate_on_sequence(
-            Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
+        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
     }

-    __host__ __device__ static constexpr auto GetElementSpace()
+    __host__ __device__ static constexpr index_t GetElementSpace()
     {
         constexpr index_t element_space_unaligned = accumulate_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

-        return Number<element_space_unaligned>{};
+        return element_space_unaligned;
     }

     // emulate constexpr lambda

@@ -157,14 +156,13 @@ struct ConstantTensorDescriptor
     }

     template <index_t... Is>
-    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
+    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

         constexpr auto multi_id = Sequence<Is...>{};

-        return Number<accumulate_on_sequence(
-            multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
+        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
     }

     // emulate constexpr lambda
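The descriptor getters above now return a plain index_t instead of a Number<> wrapper. Because they stay static constexpr, their results remain usable wherever a compile-time constant is required; a tiny sketch with illustrative types:

    #include <array>

    using index_t = int;

    struct DescSketch
    {
        // same shape as the new getters: constexpr, but returning a plain index_t
        static constexpr index_t GetNumOfDimension() { return 4; }
    };

    // still usable as an array bound / non-type template argument
    std::array<int, DescSketch::GetNumOfDimension()> buf{};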
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -142,47 +142,80 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        // assertion for inline asm
-        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
-                          is_same<FloatC, float>{},
-                      "Run_amd_asm only deal with float");
-
         static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                           MPerThread == 8 && NPerThread == 8,
                       "Run_amd_asm cannot deal with this GEMM shape yet");

         static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read");

-        using Float4 = vector_type<float, 4>::MemoryType;
-
-        Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
-        Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
-        Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
-
-        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
-        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
-        reg_b[1] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
-        reg_a[1] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
-        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-#pragma unroll
-        for(index_t k = 1; k < K; ++k)
-        {
-            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-            reg_b[1] = *reinterpret_cast<const Float4*>(
-                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
-            reg_a[1] = *reinterpret_cast<const Float4*>(
-                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-        }
-        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+        // If A and B datatype is float
+        static_if<std::is_same<FloatA, float>::value &&
+                  std::is_same<FloatB, float>::value>{}([&](auto) {
+            using Float4 = vector_type<float, 4>::MemoryType;
+
+            Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
+            Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
+            Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
+
+            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
+            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
+            reg_b[1] =
+                *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
+            reg_a[1] =
+                *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+#pragma unroll
+            for(index_t k = 1; k < K; ++k)
+            {
+                reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
+                outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+                reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
+                outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+                reg_b[1] = *reinterpret_cast<const Float4*>(
+                    &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+                reg_a[1] = *reinterpret_cast<const Float4*>(
+                    &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
+                outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+                outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            }
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+        }).Else([&](auto) {
+            // If A and B datatype is bfloat16/float16
+            using Half4x4 = vector_type<vector_type<half, 4>, 4>;
+            using Float4  = vector_type<float, 4>::MemoryType;
+
+            Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
+            Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
+            Float4* reg_c  = reinterpret_cast<Float4*>(p_c_thread);
+
+            reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
+            reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
+            reg_b[1] =
+                *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
+            reg_a[1] =
+                *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+#pragma unroll
+            for(index_t k = 1; k < K; ++k)
+            {
+                reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M]);
+                outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+                reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N]);
+                outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+                reg_b[1] = *reinterpret_cast<const Half4x4*>(
+                    &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+                reg_a[1] = *reinterpret_cast<const Half4x4*>(
+                    &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
+                outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+                outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            }
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+        });
     }
 #endif

@@ -204,11 +237,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
             make_ConstantMatrixDescriptor(
-                Number<KPerThreadLoop>{}, Number<MPerThread>{});
+                Number<KPerThreadLoop>{}, Number<MPerThread>{}, Number<MPerThread>{});

         constexpr auto b_thread_mtx =
             make_ConstantMatrixDescriptor(
-                Number<KPerThreadLoop>{}, Number<NPerThread>{});
+                Number<KPerThreadLoop>{}, Number<NPerThread>{}, Number<NPerThread>{});

         // thread A-sub, B-sub for copy
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(

@@ -217,8 +250,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace() * 4];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 4];

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

@@ -237,10 +270,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             threadwise_matrix_copy(
                 a_block_mtx,
                 p_a_block +
-                    a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
-                    mMyThreadOffsetA,
+                    (a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+                     mMyThreadOffsetA) * 4,
                 a_thread_mtx,
-                p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
+                p_a_thread + (a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC)) * 4,
                 a_thread_sub_mtx.GetLengths(),
                 Number<DataPerReadA>{});
         }

@@ -252,10 +285,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             threadwise_matrix_copy(
                 b_block_mtx,
                 p_b_block +
-                    b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
-                    mMyThreadOffsetB,
+                    (b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+                     mMyThreadOffsetB) * 4,
                 b_thread_mtx,
-                p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
+                p_b_thread + (b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC)) * 4,
                 b_thread_sub_mtx.GetLengths(),
                 Number<DataPerReadB>{});
         }

@@ -415,7 +448,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     {
 #if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
-        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+        static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
+            [&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })
+            .Else([&](auto) {
+                // If A and B datatype is bfloat16/float16
+                Run_amd_asm(p_a_block, p_b_block, p_c_thread);
+            });
 #else
         Run_source(p_a_block, p_b_block, p_c_thread);
 #endif
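The blockwise GEMM above now branches at compile time on the element type (float path vs. fp16/bf16 path) using the repo's static_if helper. A minimal sketch of the same dispatch written with standard C++17 if constexpr (illustrative names, not the repo's API):

    #include <cstddef>
    #include <type_traits>

    // Compile-time dispatch on the element type; only the selected branch is
    // instantiated, just like the static_if<>{}(...).Else(...) pattern above.
    template <class T>
    float dot_dispatch(const T* a, const T* b, std::size_t n)
    {
        float acc = 0.0f;
        if constexpr(std::is_same<T, float>::value)
        {
            for(std::size_t i = 0; i < n; ++i) // float path: multiply directly
                acc += a[i] * b[i];
        }
        else
        {
            for(std::size_t i = 0; i < n; ++i) // reduced-precision path: widen first
                acc += static_cast<float>(a[i]) * static_cast<float>(b[i]);
        }
        return acc;
    }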
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp

@@ -10,13 +10,15 @@
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
 #endif

+#define JOINTCAT(x, y) x##y
+#define ASSERT_MSG_ARG1(msg, var1) JOINTCAT(msg, var1)
+#define ASSERT_MSG_ARG2(msg, var1, va2) ASSERT_MSG_ARG1(JOINTCAT(msg, var1), var2)
+
 namespace ck {

 // slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
-// memory layout (ordering of dimensions) can be different between src and dst.
-// on a merged dimension that constains multiple original dimensions,
-// its sub-length need to evenly divide the length of the last original dimension
+// memory layout (ordering of dimensions) can be different between src and dst
+// For now, only support SubLengths[...] == 1 on a merged dimension
 // so each thread is effectively reading a normal (not merged) tensor
 template <index_t BlockSize,
           class Float,
           class SrcDesc,

@@ -77,15 +79,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // thread cluster
         constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
-            DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
+            DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));

         // BlockSize
-        static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
+        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
+                      "wrong! block size doesn't match with thread cluster size.");

         // divide work
         constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};

-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+
             static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into sub-tensor");

@@ -93,23 +98,15 @@ struct BlockwiseGenericTensorSliceCopy_v1
                           "wrong! cannot evenly divide sliced tensor into cluster");
         });

-        // on a merged dimension that constains multiple original dimensions,
-        // its sub-length need to evenly divide the length of the last original dimension,
-        // so each thread is effectively reading a normal (not merged) tensor
-        static_for<0, nDim, 1>{}([&](auto IDim) {
-            constexpr auto sub_length = SubLengths::Get(IDim);
-
-            constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
-
-            constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
+        // for now, only support SubLengths == 1 on a merged dimension that constains
+        // multiple original dimensions
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+
+            static_assert(SubLengths::Get(IDim) == 1 ||
+                              (!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
+                               !DstDesc::ContainMultipleOriginalDimensions(IDim)),
+                          "wrong! only support Sub-Length == 1 on a merged dimension");
         });

         // calculate mThreadSrcOffset, mThreadDstOffset

@@ -129,25 +126,31 @@ struct BlockwiseGenericTensorSliceCopy_v1
             dst_block_data_multi_id_begin + thread_data_multi_id_begin);

         // partial offset on each dimension
-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+            constexpr index_t idim = IDim;
+
             constexpr auto src_partial_original_dims =
                 SrcDesc::GetContainedOriginalDimensions(IDim);

             constexpr auto src_partial_original_desc =
                 SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

-            mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
+            mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
                 extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
         });

-        static_for<0, nDim, 1>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim = decltype(IDim_){};
+            constexpr index_t idim = IDim;
+
             constexpr auto dst_partial_original_dims =
                 DstDesc::GetContainedOriginalDimensions(IDim);

             constexpr auto dst_partial_original_desc =
                 DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);

-            mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
+            mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
                 extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
         });

@@ -181,8 +184,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto thread_tensor_desc =
             make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);

+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
+
             constexpr auto src_thread_data_multi_id_begin =
                 repeat_multi_id * data_per_cluster_per_dims;

@@ -195,25 +200,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
             constexpr index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
+
             const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;

             const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;

             const index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+                SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);

             const index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #endif

-            // By position the origin of the per-thread window at the point, where multi-index
-            // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-            // is assuming each thread is copy a noraml (not merged) tensor.
-            // User need to guarantee this is true.
-            // By setting SubLengths = 1 at the merged dimension, this is always true;
-            // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-            // special care in implementation is needed
             threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                     p_src + src_offset + mThreadSrcOffset,
                                                     make_zero_array<index_t, nDim>(),

@@ -238,8 +237,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto thread_tensor_desc =
             make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);

+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
+
             constexpr auto clipboard_data_multi_id_begin =
                 repeat_multi_id * thread_sub_tensor_lengths;

@@ -249,9 +250,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);

             constexpr index_t dst_offset =
-                DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+                DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
+            constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
+
             const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;

             const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;

@@ -259,16 +261,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
             const index_t clipboard_offset =
                 thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);

-            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #endif

-            // By position the origin of the per-thread window at the point, where multi-index
-            // of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
-            // is assuming each thread is copy a noraml (not merged) tensor.
-            // User need to guarantee this is true.
-            // By setting SubLengths = 1 at the merged dimension, this is always true;
-            // If in the future, you want to enable SubLengths > 1 at the merged dimension,
-            // special care in implementation is needed
             threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
                                                     p_clipboard + clipboard_offset,
                                                     make_zero_array<index_t, nDim>(),

@@ -302,7 +297,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
     __device__ void MoveSlicingWindowOnSourceTensor(
         Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
     {
         constexpr auto IDim = Number<IDim_>{};
+        constexpr index_t idim = IDim;

         static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
             // logic for a merged dimension, also works for non-merged dimension, but its logic may

@@ -325,21 +321,22 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 old_src_partial_original_multi_id, StepSize, direction);

             // update "mThreadSrcOriginalMultiId"
-            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
-                constexpr auto IDimOriginal = src_partial_original_dims[I];
+            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
+                constexpr auto I = decltype(I_){};
+                constexpr index_t idim_original = src_partial_original_dims.Get(I);

-                mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
+                mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
             });

             // calculate new partial offset on this merged dimension
-            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
+            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];

             const index_t new_src_partial_offset =
                 src_partial_original_desc.GetOffsetFromMultiIndex(
                     new_src_partial_original_multi_id);

             // update "mThreadSrcPartialOffsets"
-            mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
+            mThreadSrcPartialOffsets(idim) = new_src_partial_offset;

             // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
             mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;

@@ -354,20 +351,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
             // of the boundary of the tensor being sliced. Otherwise, there might be hazard like
             // unsigned integer underflow. That is NO runtime sanity check to prevent the hazard

-            constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+            constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();

             static_if<PositiveDirection>{}([&](auto fwd) {
                 mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);

-                mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
+                mThreadSrcOriginalMultiId(idim_original) += StepSize;

-                mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
             }).Else([&](auto fwd) {
                 mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);

-                mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
+                mThreadSrcOriginalMultiId(idim_original) -= StepSize;

-                mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
             });
         });
     }
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -3,6 +3,7 @@

 #include "common_header.hpp"
 #include "ConstantMatrixDescriptor.hpp"
+#include "float_types.h"

 namespace ck {

@@ -34,22 +35,60 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
 {
     static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");

-    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
-
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

-    for(index_t i = 0; i < NRow; ++i)
-    {
-        for(index_t j = 0; j < NCol; j += DataPerRead)
-        {
-            const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-            const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-            *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-                *reinterpret_cast<const vector_t*>(&p_src[src_index]);
-        }
-    }
+    // Depending upon datatype i.e float/half/bfloat16, carry out data movement
+    // in appropriate vectorized form
+    // float - 4, half - 4, bfloat16 - 2
+    static_if<std::is_same<Float, float>::value>{}([&](auto) {
+        using vector_t = typename vector_type<float, DataPerRead>::MemoryType;
+        for(index_t i = 0; i < NRow; ++i)
+        {
+            for(index_t j = 0; j < NCol; j += DataPerRead)
+            {
+                const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
+                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+            }
+        }
+    }).Else([&](auto) {
+        static_if<std::is_same<Float, half>::value>{}([&](auto) {
+            // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+            using vector_t = typename vector_type<Float, 4>::MemoryType;
+            for(index_t i = 0; i < NRow; ++i)
+            {
+                for(index_t j = 0; j < NCol; ++j)
+                {
+                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
+                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
+                }
+            }
+        }).Else([&](auto) {
+            using vector_t = typename vector_type<Float, 2>::MemoryType;
+            for(index_t i = 0; i < NRow; ++i)
+            {
+                for(index_t j = 0; j < NCol; ++j)
+                {
+                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
+                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
+                }
+            }
+        });
+    });
 }

 template <class MatrixA,

@@ -80,6 +119,7 @@ __device__ void threadwise_gemm(MatrixA,
     constexpr index_t N = c_mtx.NCol();

     constexpr index_t K = a_mtx.NRow(); // A is transposed
+
     for(index_t k = 0; k < K; ++k)
     {
         for(index_t i = 0; i < M; ++i)

@@ -90,7 +130,32 @@ __device__ void threadwise_gemm(MatrixA,
                 const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                 const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-                p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
+                static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
+                    p_c_thread[cindex] +=
+                        CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
+                }).Else([&](auto) {
+                    static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
+                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                        // respectively)
+                        float acc = 0.0;
+                        for(index_t v = 0; v < 4; ++v)
+                        {
+                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
+                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
+                        }
+                        p_c_thread[cindex] += acc;
+                    }).Else([&](auto) {
+                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                        // respectively)
+                        float acc = 0.0;
+                        for(index_t v = 0; v < 2; ++v)
+                        {
+                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
+                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
+                        }
+                        p_c_thread[cindex] += acc;
+                    });
+                });
             }
         }
     }
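The threadwise copy above picks a vector width per element type (4 for float and half, 2 for bfloat16) and moves whole packs through reinterpret_cast. A stripped-down sketch of that per-type packed move (illustrative names; the aliasing through reinterpret_cast mirrors the idiom in the file above rather than strictly portable C++, and ncol is assumed to be a multiple of Width):

    template <class T, int Width>
    struct pack { T v[Width]; };   // stand-in for vector_type<T, Width>::MemoryType

    // Width = 4 for float/half, 2 for bfloat16 in the hunk above
    template <class T, int Width>
    void copy_row_packed(const T* src, T* dst, int ncol)
    {
        for(int j = 0; j < ncol; j += Width)
            *reinterpret_cast<pack<T, Width>*>(&dst[j]) =
                *reinterpret_cast<const pack<T, Width>*>(&src[j]);
    }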
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp

@@ -4,6 +4,7 @@

 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
+#include "float_types.h"

 #ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0

@@ -12,7 +13,8 @@
 namespace ck {

 // user need to make sure alignment requirement is satisfied when setting DataPerAccesss > 1
-template <class Float,
+template <class SrcFloat,
+          class DesFloat,
           class SrcDesc,
           class DstDesc,
           class SliceLengths,

@@ -20,10 +22,10 @@ template <class Float,
           index_t DataPerAccess>
 __device__ void threadwise_generic_tensor_slice_copy_v1(
     SrcDesc,
-    const Float* __restrict__ p_src,
+    const SrcFloat* __restrict__ p_src,
     Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
     DstDesc,
-    Float* __restrict__ p_dst,
+    DesFloat* __restrict__ p_dst,
     Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
     SliceLengths,
     DimAccessOrder,

@@ -64,7 +66,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
         Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});

-    using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
+    using vector_src_t  = typename vector_type<SrcFloat, DataPerAccess>::MemoryType;
+    using vector_dest_t = typename vector_type<DesFloat, DataPerAccess>::MemoryType;

 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
     static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {

@@ -82,8 +85,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
         const index_t dst_index =
             DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);

-        *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-            *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+        static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
+            *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
+                *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
+        }).Else([&](auto) {
+            for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
+            {
+                p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
+            }
+        });
     });
 #else
     ford<decltype(access_lengths)>{}([&](auto access_multi_id) {

@@ -99,8 +109,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
         const index_t dst_index =
             DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);

-        *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-            *reinterpret_cast<const vector_t*>(&p_src[src_index]);
+        static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
+            *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
+                *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
+        }).Else([&](auto) {
+            for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
+            {
+                p_dst[dst_index + data_idx] = CVT_ACCUM2FLOAT(p_src[src_index + data_idx]);
+            }
+        });
     });
 #endif
 }
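The slice copy above now distinguishes same-type accesses, which stay vectorized, from mixed-type accesses, which convert element by element (the diff uses CVT_ACCUM2FLOAT for the conversion). A small sketch of the same split using standard if constexpr (illustrative names, not the repo's API):

    #include <type_traits>

    template <class SrcT, class DstT, int DataPerAccess>
    void copy_access(const SrcT* src, DstT* dst)
    {
        if constexpr(std::is_same<SrcT, DstT>::value)
        {
            for(int i = 0; i < DataPerAccess; ++i) // same-type path, vectorizable as-is
                dst[i] = src[i];
        }
        else
        {
            for(int i = 0; i < DataPerAccess; ++i) // mixed-type path, e.g. float -> half
                dst[i] = static_cast<DstT>(src[i]);
        }
    }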
composable_kernel/include/utility/Sequence.hpp

@@ -16,32 +16,31 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);

-    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
+    __host__ __device__ static constexpr index_t GetSize() { return mSize; }

-    __host__ __device__ static constexpr index_t GetImpl(index_t I)
-    {
-        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[I];
-    }
-
-    template <index_t I>
-    __host__ __device__ static constexpr auto Get(Number<I>)
-    {
-        static_assert(I < mSize, "wrong! I too large");
-
-        return Number<GetImpl(Number<I>{})>{};
-    }
-
-    template <index_t I>
-    // make sure I is constepxr
-    __host__ __device__ constexpr auto operator[](Number<I>) const
-    {
-        return Get(Number<I>{});
-    }
-
-    // make sure I is constepxr if you want a constexpr return type
-    __host__ __device__ constexpr index_t operator[](index_t I) const
-    {
-        return GetImpl(I);
-    }
+    template <index_t I>
+    __host__ __device__ static constexpr index_t Get(Number<I>)
+    {
+        static_assert(I < mSize, "wrong! I too large");
+
+        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
+        const index_t mData[mSize + 1] = {Is..., 0};
+        return mData[I];
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr auto operator[](Number<I>) const
+    {
+        return Number<Get(Number<I>{})>{};
+    }
+
+    __host__ __device__ constexpr index_t operator[](index_t I) const
+    {
+        const index_t mData[mSize + 1] = {Is..., 0};
+        return mData[I];
+    }

     template <index_t... IRs>
     __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
     {

@@ -55,16 +54,16 @@ struct Sequence
     __host__ __device__ static constexpr auto Reverse();

-    __host__ __device__ static constexpr auto Front()
-    {
-        static_assert(mSize > 0, "wrong!");
-        return Get(Number<0>{});
-    }
-
-    __host__ __device__ static constexpr auto Back()
-    {
-        static_assert(mSize > 0, "wrong!");
-        return Get(Number<mSize - 1>{});
-    }
+    __host__ __device__ static constexpr index_t Front()
+    {
+        const index_t mData[mSize + 1] = {Is..., 0};
+        return mData[0];
+    }
+
+    __host__ __device__ static constexpr index_t Back()
+    {
+        const index_t mData[mSize + 1] = {Is..., 0};
+        return mData[mSize - 1];
+    }

     __host__ __device__ static constexpr auto PopFront();
composable_kernel/include/utility/amd_inline_asm.hpp
View file @
583755a7
...
@@ -118,6 +118,58 @@ __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
     outerProduct1x4(a.w, b, c3);
 }
 
+__device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
+{
+    asm volatile("\n \
+                v_dot2_f32_f16 %0, %4, %6 %0 \n \
+                v_dot2_f32_f16 %1, %4, %8 %1 \n \
+                v_dot2_f32_f16 %2, %4, %10 %2 \n \
+                v_dot2_f32_f16 %3, %4, %12 %3 \n \
+                v_dot2_f32_f16 %0, %5, %7 %0 \n \
+                v_dot2_f32_f16 %1, %5, %9 %1 \n \
+                v_dot2_f32_f16 %2, %5, %11 %2 \n \
+                v_dot2_f32_f16 %3, %5, %13 %3 \n \
+                "
+                : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) // Dest registers
+                : "v"(a[0]), "v"(a[1]),                          // 1st Src registers for 2 half2 registers
+                  "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]),    // 2nd Src registers for 2 half2 registers
+                  "v"(b[4]), "v"(b[5]), "v"(b[6]), "v"(b[7]),    // 2nd Src registers for 2 half2 registers
+                  "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]));   // 3rd Src Acc registers for 2 half2 registers
+}
+
+__device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
+                                    const vector_type<vector_type<half, 4>, 4>& b,
+                                    vector_type<float, 4>::MemoryType& c)
+{
+    outerProduct1x4(reinterpret_cast<const half2*>(&a),
+                    reinterpret_cast<const half2*>(&b),
+                    reinterpret_cast<float*>(&c));
+}
+
+__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>, 4>& a,
+                                const vector_type<vector_type<half, 4>, 4>& b,
+                                vector_type<float, 4>::MemoryType& c0,
+                                vector_type<float, 4>::MemoryType& c1,
+                                vector_type<float, 4>::MemoryType& c2,
+                                vector_type<float, 4>::MemoryType& c3)
+{
+    const vector_type<half, 4>* reg_a = reinterpret_cast<const vector_type<half, 4>*>(&a);
+
+    outerProduct1x4Half(reg_a[0], b, c0);
+    outerProduct1x4Half(reg_a[1], b, c1);
+    outerProduct1x4Half(reg_a[2], b, c2);
+    outerProduct1x4Half(reg_a[3], b, c3);
+}
+
 __device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
                                 const vector_type<float, 4>::MemoryType* b,
                                 vector_type<float, 4>::MemoryType* c)
...
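For orientation, the arithmetic the eight v_dot2_f32_f16 instructions perform is an accumulating 4x4 outer product over a folded length-4 inner dimension, with the products accumulated in fp32. The scalar reference below is an illustration only (inputs widened to plain float rather than packed half2); it is not the code the kernel runs.

    // c[i][j] += dot(a_row_i, b_row_j) over the 4-wide folded "ES" dimension.
    void outerProduct4x4_reference(const float a[4][4], const float b[4][4], float c[4][4])
    {
        for(int i = 0; i < 4; ++i)
            for(int j = 0; j < 4; ++j)
                for(int e = 0; e < 4; ++e)
                    c[i][j] += a[i][e] * b[j][e];
    }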
composable_kernel/include/utility/bfloat16_dev.hpp
0 → 100644
View file @
583755a7
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#ifdef __cplusplus
extern "C" {
#endif

#ifdef __HIP_PLATFORM_HCC__
#define EXECUTION_SPECIFIER __device__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP

typedef union
{
    uint u32;
    ushort2 ushortx2;
// Composable kernels are written in HIP language. The language doesnt support
// ushort2.hi or ushort2.low.
#ifdef __HIP_PLATFORM_HCC__
    ushort ushortvec[2];
#endif // MIOPEN_BACKEND_HIP
    float f32;
} cvt_bf16_fp32_t;

EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
{
    cvt_bf16_fp32_t target_val;
#ifdef __HIP_PLATFORM_HCC__
    target_val.ushortx2 = make_ushort2(0, src_val);
#else
    target_val.ushortx2 = (ushort2)(0, src_val);
#endif
    return target_val.f32;
}

EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
{
    cvt_bf16_fp32_t target_val;
    target_val.f32 = src_val;
    // BF16 round and NaN preservation code matches
    // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
    if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bloat16's mantissa bits are all 0.
        if((target_val.u32 & 0xffff) != 0)
        {
            target_val.u32 |= 0x10000; // Preserve signaling NaN
        }
    }
    else
    {
#ifdef MIOPEN_USE_RNE_BFLOAT16
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
#ifdef __HIP_PLATFORM_HCC__
        target_val.u32 += (0x7fff + (target_val.ushortvec[0] & 1));
#else
        target_val.u32 += (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_RNE_BFLOAT16
    }

#ifdef __HIP_PLATFORM_HCC__
    return target_val.ushortvec[0];
#else
    return target_val.ushortx2.hi;
#endif // MIOPEN_BACKEND_HIP
}

#ifdef __cplusplus
}
#endif

#endif // BFLOAT16_DEVICE_HPP
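The same rounding can be checked on the host. The sketch below follows the rocBLAS reference linked in the comments (round-to-nearest-even by adding 0x7FFF plus the LSB of the would-be bfloat16 mantissa) rather than mirroring the union layout above, and it assumes a little-endian host; it is an illustration, not part of this commit.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static uint16_t float_to_bf16_rne(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        if((~u & 0x7f800000u) == 0) // Inf or NaN: preserve signalling NaN
        {
            if((u & 0xffffu) != 0)
                u |= 0x10000u;
        }
        else
        {
            u += 0x7fffu + ((u >> 16) & 1u); // round to nearest, ties to even
        }
        return static_cast<uint16_t>(u >> 16);
    }

    int main()
    {
        std::printf("bf16(1.0f) = 0x%04x\n", float_to_bf16_rne(1.0f)); // expect 0x3f80
        return 0;
    }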
composable_kernel/include/utility/common_header.hpp
View file @
583755a7
 #ifndef CK_COMMON_HEADER_HPP
 #define CK_COMMON_HEADER_HPP
 
+#define MIOPEN_USE_FP16 1
+#define MIOPEN_USE_BFP16 0
+#define MIOPEN_USE_FP32 0
+#define __HIP_PLATFORM_HCC__ 1
+
 #include "config.hpp"
 #include "utility.hpp"
 #include "integral_constant.hpp"
...
composable_kernel/include/utility/float_types.h
0 → 100644
View file @
583755a7
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef FLOAT_TYPES_HPP
#define FLOAT_TYPES_HPP
#include "bfloat16_dev.hpp"
#define PPCAT_NX(A, B) A##B
#define PPCAT(A, B) PPCAT_NX(A, B)
#define TWO 2
#define FOUR 4
#define EIGHT 8
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT half
#define FLOAT_ACCUM float
#else
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define _FLOAT half
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 2
/* sizeof is unavailable for preprocessor */
#ifndef HALF_MAX
#define MAX_VAL 65504
/* max value */
#else
#define MAX_VAL HALF_MAX
#endif // HALF_MAX
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT float
#define FLOAT_ACCUM float
#else
#define _FLOAT float
#define _FLOAT_ACCUM float
#endif // __HIP_PLATFORM_HCC__
#define SIZEOF_FLOAT 4
/* sizeof is unavailable for preprocessor */
#ifndef FLT_MAX
#define MAX_VAL 3.402823466e+38F
/* max value */
#else
#define MAX_VAL FLT_MAX
#endif // FLT_MAX
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define FLOAT ushort
#define FLOAT_ACCUM float
#else
#define _FLOAT ushort
#define _FLOAT_ACCUM float
#endif //
#define SIZEOF_FLOAT 2
/* sizeof is unavailable for preprocessor */
#define MAX_VAL 0x7F7F
/* max value */
#endif // MIOPEN_USE_BFP16
#if MIOPEN_USE_FP16 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_FP16
#if MIOPEN_USE_FP32 == 1
#ifdef __HIP_PLATFORM_HCC__
#define CVT_FLOAT2ACCUM(x) (static_cast<FLOAT_ACCUM>(x))
#define CVT_ACCUM2FLOAT(x) (static_cast<FLOAT>(x))
#else
#define CVT_FLOAT2ACCUM(x) ((_FLOAT_ACCUM)(x))
#define CVT_ACCUM2FLOAT(x) ((_FLOAT)(x))
#endif
#endif // MIOPEN_USE_FP32
#if MIOPEN_USE_BFP16 == 1
#define CVT_FLOAT2ACCUM(x) bfloat16_to_float(x)
#define CVT_ACCUM2FLOAT(x) float_to_bfloat16(x)
#endif
#ifndef __HIP_PLATFORM_HCC__
#define _FLOAT2 PPCAT(_FLOAT, TWO)
#endif
#endif // FLOAT_TYPES_HPP
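How the FLOAT / FLOAT_ACCUM pair is meant to be used on the HIP fp16 path, as a small illustrative device function (a sketch, not code from this commit): values are read in the storage type, accumulated in the wider accumulator type, and narrowed once at the end with CVT_ACCUM2FLOAT.

    #define MIOPEN_USE_FP16 1
    #define MIOPEN_USE_BFP16 0
    #define MIOPEN_USE_FP32 0
    #include "float_types.h"

    // With MIOPEN_USE_FP16 on the HIP path, FLOAT is half and FLOAT_ACCUM is float.
    __device__ FLOAT reduce_example(const FLOAT* p, int n)
    {
        FLOAT_ACCUM acc = 0;
        for(int i = 0; i < n; ++i)
            acc += CVT_FLOAT2ACCUM(p[i]);
        return CVT_ACCUM2FLOAT(acc);
    }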
composable_kernel/include/utility/integral_constant.hpp
View file @
583755a7
...
@@ -13,64 +13,30 @@ struct integral_constant
     __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
 };
 
-template <class X, class Y>
-struct is_same : public integral_constant<bool, false>
-{
-};
-
-template <class X>
-struct is_same<X, X> : public integral_constant<bool, true>
-{
-};
-
-template <index_t N>
-using Number = integral_constant<index_t, N>;
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
-{
-    return Number<X + Y>{};
-}
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
-{
-    static_assert(Y <= X, "wrong!");
-    return Number<X - Y>{};
-}
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
-{
-    return Number<X * Y>{};
-}
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
-{
-    static_assert(Y > 0, "wrong!");
-    return Number<X / Y>{};
-}
-
-template <index_t X, index_t Y>
-__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
-{
-    static_assert(Y > 0, "wrong!");
-    return Number<X % Y>{};
-}
-
-#if 0
-static constexpr Number<0> 0_c;
-static constexpr Number<1> 1_c;
-static constexpr Number<2> 2_c;
-static constexpr Number<3> 3_c;
-static constexpr Number<4> 4_c;
-static constexpr Number<5> 5_c;
-static constexpr Number<6> 6_c;
-static constexpr Number<7> 7_c;
-static constexpr Number<8> 8_c;
-static constexpr Number<9> 9_c;
-#endif
+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
+{
+    return integral_constant<T, X + Y>{};
+}
+
+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
+{
+    return integral_constant<T, X * Y>{};
+}
+
+template <index_t N>
+using Number = integral_constant<index_t, N>;
+
+template <class X, class Y>
+struct is_same : public integral_constant<bool, false>
+{
+};
+
+template <class X>
+struct is_same<X, X> : public integral_constant<bool, true>
+{
+};
 
 } // namespace ck
 #endif
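A self-contained illustration of the compile-time arithmetic the new operators enable, using std::integral_constant as a stand-in for the header's own class (names here are for illustration only):

    #include <type_traits>

    template <class T, T v>
    using ic = std::integral_constant<T, v>;

    template <class T, T X, T Y>
    constexpr auto operator+(ic<T, X>, ic<T, Y>)
    {
        return ic<T, X + Y>{};
    }

    template <class T, T X, T Y>
    constexpr auto operator*(ic<T, X>, ic<T, Y>)
    {
        return ic<T, X * Y>{};
    }

    // Arithmetic happens entirely in the type system; the results are new constants.
    static_assert(decltype(ic<int, 2>{} + ic<int, 3>{})::value == 5, "compile-time add");
    static_assert(decltype(ic<int, 2>{} * ic<int, 3>{})::value == 6, "compile-time multiply");

    int main() { return 0; }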
composable_kernel/include/utility/math.hpp
View file @
583755a7
...
@@ -42,16 +42,20 @@ struct integer_divide_ceiler
     }
 };
 
-template <class X, class Y>
-__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
-{
-    return (x + y - 1) / y;
-}
+template <class T>
+__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
+{
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
+
+    return (a + b - 1) / b;
+}
 
-template <class X, class Y>
-__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
-{
-    return y * integer_divide_ceil(x, y);
-}
+template <class T>
+__host__ __device__ constexpr T integer_least_multiple(T a, T b)
+{
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
+
+    return b * integer_divide_ceil(a, b);
+}
 
 template <class T>
...
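Both helpers are ordinary ceiling arithmetic: integer_divide_ceil(a, b) is ceil(a / b) for positive integers, and integer_least_multiple(a, b) is the smallest multiple of b that is at least a. A quick standalone check (values chosen arbitrarily):

    #include <cstdio>

    constexpr int integer_divide_ceil_ref(int a, int b) { return (a + b - 1) / b; }
    constexpr int integer_least_multiple_ref(int a, int b) { return b * integer_divide_ceil_ref(a, b); }

    int main()
    {
        static_assert(integer_divide_ceil_ref(70, 64) == 2, "ceil(70/64) == 2");
        static_assert(integer_least_multiple_ref(70, 64) == 128, "round 70 up to 128");
        std::printf("%d %d\n", integer_divide_ceil_ref(70, 64), integer_least_multiple_ref(70, 64));
        return 0;
    }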
composable_kernel/include/utility/vector_type.hpp
View file @
583755a7
 #ifndef CK_VECTOR_TYPE_HPP
 #define CK_VECTOR_TYPE_HPP
 
+#include "cuda_fp16.h"
 #include "config.hpp"
 #include "integral_constant.hpp"
...
@@ -9,12 +10,15 @@ namespace ck {
 template <class T, index_t N>
 struct vector_type
 {
+    T vector[N];
 };
 
 template <>
 struct vector_type<float, 1>
 {
-    typedef float MemoryType;
+    using MemoryType = float;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 1; }
 
     template <index_t I>
     __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
...
@@ -29,6 +33,8 @@ struct vector_type<float, 2>
 {
     using MemoryType = float2_t;
 
+    __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
     union Data
     {
         MemoryType vector;
...
@@ -42,13 +48,6 @@ struct vector_type<float, 2>
         *(reinterpret_cast<float*>(&v) + I) = s;
     }
 
-    __host__ __device__ static MemoryType Pack(float s0, float s1)
-    {
-        Data data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
 };
 
 template <>
...
@@ -56,6 +55,8 @@ struct vector_type<float, 4>
 {
     using MemoryType = float4_t;
 
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
     template <index_t I>
     __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
     {
...
@@ -64,6 +65,116 @@ struct vector_type<float, 4>
     }
 };
 
+template <>
+struct vector_type<half, 1>
+{
+    using MemoryType = half;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 1; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 2>
+{
+    using MemoryType = half2;
+
+    union Data
+    {
+        MemoryType vector;
+        half scalar[2];
+    };
+
+    __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 4>
+{
+    typedef struct MemoryType
+    {
+        half2 vector[2];
+    } MemoryType;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 1>
+{
+    using MemoryType = ushort;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 1; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 2>
+{
+    using MemoryType = ushort2;
+
+    union Data
+    {
+        MemoryType vector;
+        half scalar[2];
+    };
+
+    __host__ __device__ static constexpr index_t GetSize() { return 2; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 4>
+{
+    typedef struct MemoryType
+    {
+        ushort2 vector[2];
+    } MemoryType;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
 } // namespace ck
 #endif
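A usage sketch for the new fp16 specialization (illustrative device code, not part of this commit; it assumes vector_type.hpp and HIP's half types are available as included above, and `fill_half4` is a hypothetical helper name):

    #include "vector_type.hpp"

    // Fill a 4-wide fp16 register tile one lane at a time via SetScalar.
    __device__ void fill_half4(ck::vector_type<half, 4>::MemoryType& v, half s)
    {
        ck::vector_type<half, 4>::SetScalar(v, s, ck::Number<0>{});
        ck::vector_type<half, 4>::SetScalar(v, s, ck::Number<1>{});
        ck::vector_type<half, 4>::SetScalar(v, s, ck::Number<2>{});
        ck::vector_type<half, 4>::SetScalar(v, s, ck::Number<3>{});
    }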
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
583755a7
 #pragma once
 #include <unistd.h>
+
+#define MIOPEN_USE_FP16 1
+#define MIOPEN_USE_BFP16 0
+#define MIOPEN_USE_FP32 0
+#define __HIP_PLATFORM_HCC__ 1
+
+#include "float_types.h"
 #include "device.hpp"
 #include "tensor.hpp"
 #include "gridwise_convolution_kernel_wrapper.hpp"
 #include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
 #include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+#include "gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+
+#define CK_PARAM_TUNABLE_K_PER_BLOCK 64
 
 using namespace ck;
...
@@ -24,6 +35,10 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                                                         ConvDilations,
                                                         index_t nrepeat)
 {
+    // read params: problem decription
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
...
@@ -59,16 +74,22 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
 
-#if 1
-    constexpr index_t BlockSize = 256;
-
     constexpr index_t BPerBlock = 16;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
+    constexpr index_t KPerBlock = K % 128 == 0 ? 128 : (K % 64 == 0 ? 64 : 32);
+    constexpr index_t BlockSize = K % 128 == 0 ? 256 : (K % 64 == 0 ? 128 : 64);
+
+#if MIOPEN_USE_FP16 == 1
+    // ES set to 4 as dot4 operator is supported on fp16 in MI100
+    constexpr index_t ES = 4;
+#elif MIOPEN_USE_BFP16 == 1
+    // ES set to 2 as dot2 operator is supported on bfp16 in MI100
+    constexpr index_t ES = 2;
+#else
+    // do nothing
+#endif
 
     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
...
@@ -76,92 +97,103 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmDataPerReadA = 4;
     constexpr index_t GemmDataPerReadB = 4;
 
-    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
-    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
+#if MIOPEN_USE_FP32 == 1
     using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
     using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
     using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
 
-    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
-    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
-
-    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-
     using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
     using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
     using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#elif 0
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t BPerBlock = 16;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
-
-    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 4, 1>;
-    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>;
-
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
-
-    constexpr index_t InBlockCopySrcDataPerRead_B   = 4;
-    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
-
-    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#elif 1
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t BPerBlock = 16;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
-
-    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 2, 2>;
-    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>;
-
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
-
-    constexpr index_t InBlockCopySrcDataPerRead_B   = 2;
-    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
-
-    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#endif
+#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
+    // ES - E dimension is folded into 2 dimensions E and ES
+    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2, 4>; // [E, N1, N2, B, ES]
+    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2, 4>; // [E, N1, N2, B, ES]
+    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3, 4>; // [E, N1, B, N2, ES]
+
+    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0, 2>; // [K, E, ES]
+    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0, 2>; // [K, E, ES]
+    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1, 2>; // [E, K, ES]
+#endif
+
+#if CK_PARAM_TUNABLE_K_PER_BLOCK == 32
+    constexpr index_t EPerBlock                     = 4;
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
+    constexpr index_t GemmMLevel0Cluster            = 1;
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
+#if MIOPEN_USE_FP32 == 1
+    // all_of(X_Per_Block % (X_Sub_Length * X_Cluster_Length) == 0)
+    // accumulate(X_Cluster_Lengths, multiply) == BlockSize
+    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
+    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
+    using WeiBlockCopySubLengths_E_K          = Sequence<2, 1>;
+    using WeiBlockCopyClusterLengths_E_K      = Sequence<2, 32>;
+#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
+    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 2, 1, 4, ES>;
+    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<4, 1, 16, 1, 1>;
+    using WeiBlockCopySubLengths_E_K_ES          = Sequence<2, 1, ES>;
+    using WeiBlockCopyClusterLengths_E_K_ES      = Sequence<2, 32, 1>;
+#endif // MIOPEN_USE_FP32 == 1
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 64
+    constexpr index_t EPerBlock                     = 8;
+    constexpr index_t GemmMLevel0Cluster            = 2;
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
+#if MIOPEN_USE_FP32 == 1
+    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
+    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
+    using WeiBlockCopySubLengths_E_K          = Sequence<4, 1>;
+    using WeiBlockCopyClusterLengths_E_K      = Sequence<2, 64>;
+#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
+    // ES - E dimension is folded into 2 dimensions E and ES
+    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 2, 1, 4, ES>;
+    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 1, 16, 1, 1>;
+    using WeiBlockCopySubLengths_E_K_ES          = Sequence<4, 1, ES>;
+    using WeiBlockCopyClusterLengths_E_K_ES      = Sequence<2, 64, 1>;
+#endif // MIOPEN_USE_FP32 == 1
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+#elif CK_PARAM_TUNABLE_K_PER_BLOCK == 128
+    constexpr index_t EPerBlock                     = 8;
+    constexpr index_t GemmMLevel0Cluster            = 4;
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
+#if MIOPEN_USE_FP32 == 1
+    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
+    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
+    using WeiBlockCopySubLengths_E_K          = Sequence<4, 1>;
+    using WeiBlockCopyClusterLengths_E_K      = Sequence<2, 128>;
+#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
+    // ES - E dimension is folded into 2 dimensions E and ES
+    using InBlockCopySubLengths_E_N1_B_N2_ES     = Sequence<1, 1, 1, 4, ES>;
+    using InBlockCopyClusterLengths_E_N1_B_N2_ES = Sequence<8, 2, 16, 1, 1>;
+    using WeiBlockCopySubLengths_E_K_ES          = Sequence<4, 1, ES>;
+    using WeiBlockCopyClusterLengths_E_K_ES      = Sequence<2, 128, 1>;
+#endif // MIOPEN_USE_FP32 == 1
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+#else
+    static_assert(false, "wrong! Only kperblock could be 32/64/128 not supported");
+#endif // CK_PARAM_TUNABLE_K_PER_BLOCK == 32
 
     constexpr index_t GridSize =
         ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
...
@@ -171,47 +203,86 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
         constexpr auto gridwise_conv =
-#if 0
-            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
-#else
-            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
-#endif
-            <GridSize, BlockSize, T,
-             decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
-             ConvStrides, ConvDilations,
-             BPerBlock, KPerBlock, EPerBlock, N1, N2,
-             GemmMPerThreadSubC, GemmNPerThreadSubC,
-             GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-             GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
-             InBlockCopySubLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
-             InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
-             InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2,
-             WeiBlockCopySubLengths_E_K, WeiBlockCopyClusterLengths_E_K,
-             WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
-             WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{};
+#if MIOPEN_USE_FP32 == 1
+            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer<
+                GridSize, BlockSize, FLOAT, FLOAT_ACCUM,
+                decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
+                ConvStrides, ConvDilations,
+                BPerBlock, KPerBlock, EPerBlock, N1, N2,
+                GemmMPerThreadSubC, GemmNPerThreadSubC,
+                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+                GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
+                InBlockCopySubLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
+                InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
+                InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2,
+                WeiBlockCopySubLengths_E_K, WeiBlockCopyClusterLengths_E_K,
+                WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
+                WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{};
+#elif MIOPEN_USE_FP16 == 1 || MIOPEN_USE_BFP16 == 1
+            GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer<
+                GridSize, BlockSize, half, float,
+                decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
+                ConvStrides, ConvDilations,
+                BPerBlock, KPerBlock, EPerBlock, N1, N2, ES,
+                GemmMPerThreadSubC, GemmNPerThreadSubC,
+                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+                GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
+                InBlockCopySubLengths_E_N1_B_N2_ES, InBlockCopyClusterLengths_E_N1_B_N2_ES,
+                InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
+                InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2,
+                WeiBlockCopySubLengths_E_K_ES, WeiBlockCopyClusterLengths_E_K_ES,
+                WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
+                WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{};
+#endif
 
         float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
...
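To make the launch-size arithmetic above concrete, here is a host-side check for the small problem enabled in the driver below (N=8, K=64, Ho=Wo=4). The N1=2, N2=4 register-tile split is an assumption for this example, not a value read from this diff.

    #include <cstdio>

    int main()
    {
        const int N = 8, K = 64, Ho = 4, Wo = 4, N1 = 2, N2 = 4, BPerBlock = 16;
        // B = (N * Ho * Wo) / (N1 * N2) = 128 / 8 = 16
        const int B = (N * Ho * Wo) / (N1 * N2);
        // K % 128 != 0, K % 64 == 0  ->  KPerBlock = 64 (and BlockSize would be 128)
        const int KPerBlock = K % 128 == 0 ? 128 : (K % 64 == 0 ? 64 : 32);
        // GridSize = ceil(B / BPerBlock) * ceil(K / KPerBlock) = 1 * 1 = 1
        const int GridSize = ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
        std::printf("B=%d KPerBlock=%d GridSize=%d\n", B, KPerBlock, GridSize);
        return 0;
    }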
driver/src/driver.cpp
View file @
583755a7
...
@@ -787,14 +787,27 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
-    // 1x1 filter, 7x7 image
-    constexpr index_t N  = 32;
-    constexpr index_t C  = 128;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K  = 192;
+#elif 0
+    // 1x1 filter, 7x7 image
+    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
+    constexpr index_t N  = 128;
+    constexpr index_t C  = 832;
+    constexpr index_t HI = 7;
+    constexpr index_t WI = 7;
+    constexpr index_t K  = 128;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+
+    using ConvStrides   = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
+#elif 1
+    constexpr index_t N  = 8;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 4;
+    constexpr index_t WI = 4;
+    constexpr index_t K  = 64;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;
...
@@ -802,7 +815,7 @@ int main(int argc, char* argv[])
     using ConvDilations = Sequence<1, 1>;
 
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
 #endif
 
     auto lower_pads = Sequence<HPad, WPad>{};
...
@@ -817,8 +830,8 @@ int main(int argc, char* argv[])
     ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
 
-    using in_data_t  = float;
-    using out_data_t = float;
+    using in_data_t  = half;
+    using out_data_t = half;
 
     Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
     Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
     Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
...
@@ -897,7 +910,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-#if 1
+#if 0
         if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
            ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
         {
...
@@ -915,7 +928,6 @@ int main(int argc, char* argv[])
                                   upper_pads);
         }
-
         check_error(out_nkhw_host, out_nkhw_device);
 
 #if 0
         LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
         LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
...
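For the configuration that is now active in the driver (N=8, C=64, HI=WI=4, K=64, 1x1 filter, unit stride and dilation, no padding), the output spatial size comes out to 4x4. A quick host-side check of the standard formula (illustrative only):

    #include <cstdio>

    int main()
    {
        const int HI = 4, WI = 4, Y = 1, X = 1, stride = 1, dilation = 1, pad = 0;
        const int Ho = (HI + 2 * pad - dilation * (Y - 1) - 1) / stride + 1;
        const int Wo = (WI + 2 * pad - dilation * (X - 1) - 1) / stride + 1;
        std::printf("Ho=%d Wo=%d\n", Ho, Wo); // prints Ho=4 Wo=4
        return 0;
    }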