gaoqiong/composable_kernel — Commit 37b82b7e

Commit 37b82b7e ("refactor"), authored Jun 19, 2019 by Chao Liu. Parent: 1f2cfceb.

Showing 7 changed files with 138 additions and 85 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp (+18 −4)
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp (+4 −4)
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp (+15 −13)
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (+29 −23)
composable_kernel/include/utility/Sequence.hpp (+18 −17)
composable_kernel/include/utility/integral_constant.hpp (+48 −14)
composable_kernel/include/utility/math.hpp (+6 −10)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -84,6 +84,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
 
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
         static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
 
         constexpr index_t N0 = N / (N1 * N2);
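The new `ConvStrideH = ConvStrides{}[0]` style works because `Sequence::operator[]` (reworked in `Sequence.hpp` below) stays usable in constant expressions. A minimal standalone sketch of the idea (`Seq` is an illustrative stand-in, not the library's real `Sequence`):

#include <cstdint>

using index_t = std::int32_t;

template <index_t... Is>
struct Seq
{
    static constexpr index_t size = sizeof...(Is);

    // plain-index lookup; constexpr-evaluable because Is... are compile-time
    constexpr index_t operator[](index_t i) const
    {
        // trailing dummy element avoids a zero-length array when the pack is empty
        const index_t data[size + 1] = {Is..., 0};
        return data[i];
    }
};

int main()
{
    constexpr auto conv_strides = Seq<2, 1>{}; // hypothetical <StrideH, StrideW>
    constexpr index_t stride_h = conv_strides[0];
    static_assert(stride_h == 2, "subscript result is still a constant expression");
}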
@@ -92,6 +98,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t E = C * Y * X;
 
+        // sanity-check for vectorized memory load
+        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
+                      "wrong! global vector load of input tensor is wrong");
+
+        static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
+                      "wrong! aligment requirement for vectorized global load of input tensor will "
+                      "be violated");
+
         // divide block work by [K, B]
         static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                       "wrong! cannot divide work evenly among block");
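As a plain-arithmetic illustration of the first assert (my sketch, not commit code): with ConvStrideW > 1, adjacent output positions along W map to non-adjacent input elements, so only scalar (width-1) global loads along W are safe:

// encodes the rule: vectorized reads along W require unit stride
constexpr bool vector_load_ok(int conv_stride_w, int data_per_read)
{
    return conv_stride_w == 1 || data_per_read == 1;
}

static_assert(vector_load_ok(1, 4), "stride 1: 4-wide vector loads are allowed");
static_assert(vector_load_ok(2, 1), "stride 2: scalar loads are still fine");
static_assert(!vector_load_ok(2, 4), "stride 2 with 4-wide loads would gather non-contiguous data");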
@@ -111,15 +125,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         // input tensor
         //   tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
         constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
-                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{})
                 .Extract(Sequence<0, 1, 2, 4, 5>{});
 
         //   batch descriptor for device memory
         constexpr auto in_c_y_x_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
-                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                 .Extract(Sequence<1, 2, 3>{});
 
         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
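For intuition on the two sliced views (a sketch of standard convolution geometry, mine rather than the commit's): the output view samples Ho window origins at step ConvStrideH, the filter view samples Y taps at step ConvDilationH, and both must fit inside the input height Hi:

// last window origin + last dilated tap must land inside the input
constexpr bool conv_geometry_ok(int hi, int y, int ho, int stride_h, int dilation_h)
{
    return (ho - 1) * stride_h + (y - 1) * dilation_h + 1 <= hi;
}

static_assert(conv_geometry_ok(8, 3, 6, 1, 1), "Hi=8, Y=3 -> Ho=6 when stride=dilation=1");
static_assert(!conv_geometry_ok(8, 3, 7, 1, 1), "Ho=7 would read past the input");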
composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor.hpp
@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
         return OriginalTensorDesc{};
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
 
     template <index_t IDim>
     __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
     }
 
     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
+    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
     {
         constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
@@ -60,7 +60,7 @@ struct ConstantMergedTensorDescriptor
     }
 
     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
+    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
     {
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");
@@ -75,7 +75,7 @@ struct ConstantMergedTensorDescriptor
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
     }
 
-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
         return OriginalTensorDesc::GetElementSize();
     }
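The pattern in this file (and the next two) is uniform: return Number<...>{} instead of index_t so the value travels in the returned type. A standalone sketch of why that matters (Desc, Array, and Number_ are illustrative names, not the library's):

template <int N>
struct Number_
{
    static constexpr int value = N;
};

struct Desc
{
    // new style: the dimension count is encoded in the returned *type*
    static constexpr auto GetNumOfDimension() { return Number_<4>{}; }
};

template <class N>
struct Array
{
    int data[N::value];
};

int main()
{
    auto n = Desc::GetNumOfDimension(); // n has type Number_<4>
    Array<decltype(n)> a{};             // usable as a template argument again
    static_assert(sizeof(a) == 4 * sizeof(int), "no extra constexpr plumbing required");
}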
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
@@ -43,22 +43,22 @@ struct ConstantTensorDescriptor
         return Sequence<IDim>{};
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
 
     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
 
     __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
 
-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetLength(Number<I>)
+    template <class IDim>
+    __host__ __device__ static constexpr auto GetLength(IDim)
     {
-        return Lengths::Get(Number<I>{});
+        return Lengths::Get(IDim{});
     }
 
-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetStride(Number<I>)
+    template <class IDim>
+    __host__ __device__ static constexpr auto GetStride(IDim)
     {
-        return Strides::Get(Number<I>{});
+        return Strides::Get(IDim{});
     }
 
     struct lambda_AreDimensionsContinuous
@@ -102,17 +102,18 @@ struct ConstantTensorDescriptor
         return false;
     }
 
-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
-        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
+        return Number<accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
     }
 
-    __host__ __device__ static constexpr index_t GetElementSpace()
+    __host__ __device__ static constexpr auto GetElementSpace()
     {
         constexpr index_t element_space_unaligned = accumulate_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
 
-        return element_space_unaligned;
+        return Number<element_space_unaligned>{};
     }
 
     // emulate constexpr lambda
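Same move for GetElementSize/GetElementSpace: the accumulated value is wrapped back into Number<...> so downstream code can keep it in the type system. A simplified sketch (mine; a C++17 fold is used for brevity where the library uses accumulate_on_sequence):

template <int N>
struct Number_
{
    static constexpr int value = N;
};

template <int... Ls>
struct Lengths_
{
    // product of all lengths, returned as a compile-time type
    static constexpr auto GetElementSize() { return Number_<(Ls * ... * 1)>{}; }
};

static_assert(decltype(Lengths_<2, 3, 4>::GetElementSize())::value == 24, "2*3*4");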
@@ -156,13 +157,14 @@ struct ConstantTensorDescriptor
     }
 
     template <index_t... Is>
-    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
+    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
 
         constexpr auto multi_id = Sequence<Is...>{};
 
-        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
+        return Number<accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
     }
 
     // emulate constexpr lambda
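The quantity being wrapped is the usual linear offset: the dot product of the multi-index with the strides. A sketch with plain arrays rather than the library's Sequence machinery:

constexpr int offset_3d(const int (&id)[3], const int (&stride)[3])
{
    return id[0] * stride[0] + id[1] * stride[1] + id[2] * stride[2];
}

// e.g. a packed 3x3x4 tensor has strides {12, 4, 1}
static_assert(offset_3d({1, 2, 3}, {12, 4, 1}) == 23, "1*12 + 2*4 + 3*1");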
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
@@ -83,9 +83,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // divide work
         constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
 
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
-
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into sub-tensor");
@@ -95,9 +93,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // for now, only support SubLengths == 1 on a merged dimension that contains
         // multiple original dimensions
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
-
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             static_assert(SubLengths::Get(IDim) == 1 ||
                               (!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
                                !DstDesc::ContainMultipleOriginalDimensions(IDim)),
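Why merged dimensions need SubLengths == 1 (my illustration, not commit code): stepping through a merged dimension crosses carries between the original dimensions, and at a carry the flat-offset jump generally differs from the within-dimension stride, so a fixed-stride sub-tensor copy would be wrong. For example, merging C and Y with a padded C stride:

// merged index e over C=2, Y=3; original strides: strideC = 10 (padded), strideY = 3
constexpr int offset_of_e(int e)
{
    const int c = e / 3;
    const int y = e % 3;
    return c * 10 + y * 3;
}

static_assert(offset_of_e(1) - offset_of_e(0) == 3, "plain step: one Y stride");
static_assert(offset_of_e(3) - offset_of_e(2) == 4, "carry step (y wraps, c bumps) is different");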
@@ -121,8 +117,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
             dst_block_data_multi_id_begin + thread_data_multi_id_begin);
 
         // partial offset on each dimension
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             constexpr index_t idim = IDim;
 
             constexpr auto src_partial_original_dims =
@@ -135,8 +130,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
         });
 
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             constexpr index_t idim = IDim;
 
             constexpr auto dst_partial_original_dims =
@@ -208,6 +202,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
             thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #endif
 
+        // We position the origin of the per-thread window at the point where the
+        // multi-index of SrcDesc (which might be a merged tensor) is all-zero. This
+        // threadwise slice copy assumes each thread copies a normal (not merged)
+        // tensor, and the user needs to guarantee that this is true. Setting
+        // SubLengths = 1 on merged dimensions makes it always true. If, in the
+        // future, SubLengths > 1 is to be enabled on a merged dimension, the
+        // implementation will need special care.
         threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                 p_src + src_offset + mThreadSrcOffset,
                                                 make_zero_array<index_t, nDim>(),
@@ -259,6 +260,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
         const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #endif
 
+        // We position the origin of the per-thread window at the point where the
+        // multi-index of SrcDesc (which might be a merged tensor) is all-zero. This
+        // threadwise slice copy assumes each thread copies a normal (not merged)
+        // tensor, and the user needs to guarantee that this is true. Setting
+        // SubLengths = 1 on merged dimensions makes it always true. If, in the
+        // future, SubLengths > 1 is to be enabled on a merged dimension, the
+        // implementation will need special care.
         threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
                                                 p_clipboard + clipboard_offset,
                                                 make_zero_array<index_t, nDim>(),
@@ -292,8 +300,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
     __device__ void MoveSlicingWindowOnSourceTensor(
         Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
     {
-        constexpr auto IDim = Number<IDim_>{};
-        constexpr index_t idim = IDim;
+        constexpr auto IDim = Number<IDim_>{};
 
         static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
             // logic for a merged dimension, also works for non-merged dimension, but its logic may
@@ -316,22 +323,21 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 old_src_partial_original_multi_id, StepSize, direction);
 
             // update "mThreadSrcOriginalMultiId"
-            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
-                constexpr auto I = decltype(I_){};
-
-                constexpr index_t idim_original = src_partial_original_dims.Get(I);
+            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
+                constexpr auto IDimOriginal = src_partial_original_dims[I];
 
-                mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
+                mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
             });
 
             // calculate new partial offset on this merged dimension
-            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
+            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
 
             const index_t new_src_partial_offset =
                 src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_multi_id);
 
             // update "mThreadSrcPartialOffsets"
-            mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
+            mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
 
             // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
             mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
@@ -346,20 +352,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
             // of the boundary of the tensor being sliced. Otherwise, there might be hazards like
             // unsigned integer underflow. There is NO runtime sanity check to prevent such hazards.
 
-            constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+            constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
 
             static_if<PositiveDirection>{}([&](auto fwd) {
                 mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
 
-                mThreadSrcOriginalMultiId(idim_original) += StepSize;
+                mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
 
-                mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
             }).Else([&](auto fwd) {
                 mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
 
-                mThreadSrcOriginalMultiId(idim_original) -= StepSize;
+                mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
 
-                mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
             });
         });
     }
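For the non-merged branch at the end of MoveSlicingWindowOnSourceTensor, the update collapses to a single stride multiply. A sketch of just that arithmetic (names mine):

constexpr int move_offset(int offset, int step, int stride, bool positive)
{
    return positive ? offset + step * stride : offset - step * stride;
}

static_assert(move_offset(100, 2, 8, true) == 116, "forward move");
static_assert(move_offset(116, 2, 8, false) == 100, "backward move undoes it");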
composable_kernel/include/utility/Sequence.hpp
@@ -16,31 +16,32 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);
 
-    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
+    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
 
-    template <index_t I>
-    __host__ __device__ static constexpr index_t Get(Number<I>)
+    __host__ __device__ static constexpr index_t GetImpl(index_t I)
     {
-        static_assert(I < mSize, "wrong! I too large");
-
         // the last dummy element prevents the compiler complaining about an empty array when mSize = 0
         const index_t mData[mSize + 1] = {Is..., 0};
         return mData[I];
     }
 
     template <index_t I>
-    __host__ __device__ constexpr auto operator[](Number<I>) const
+    __host__ __device__ static constexpr auto Get(Number<I>)
     {
-        return Number<Get(Number<I>{})>{};
+        static_assert(I < mSize, "wrong! I too large");
+
+        return Number<GetImpl(Number<I>{})>{};
     }
 
-    // make sure I is constexpr
-    __host__ __device__ constexpr index_t operator[](index_t I) const
+    template <index_t I>
+    __host__ __device__ constexpr auto operator[](Number<I>) const
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[I];
+        return Get(Number<I>{});
+    }
+
+    // make sure I is constexpr if you want a constexpr return type
+    __host__ __device__ constexpr index_t operator[](index_t I) const
+    {
+        return GetImpl(I);
     }
 
     template <index_t... IRs>
     __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
     {
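A compact standalone rendering of the new Get/GetImpl split (simplified types, same shape as the diff): GetImpl does the raw array lookup from a plain index, and Get re-wraps the result as a compile-time constant so the value survives in the type:

template <int N>
struct Number_
{
    static constexpr int value = N;
};

template <int... Is>
struct Seq_
{
    static constexpr int size = sizeof...(Is);

    static constexpr int GetImpl(int i)
    {
        const int data[size + 1] = {Is..., 0}; // dummy tail handles the empty pack
        return data[i];
    }

    template <int I>
    static constexpr auto Get(Number_<I>)
    {
        static_assert(I < size, "wrong! I too large");
        return Number_<GetImpl(I)>{};
    }
};

static_assert(decltype(Seq_<5, 7>::Get(Number_<1>{}))::value == 7, "value recovered as a type");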
@@ -54,16 +55,16 @@ struct Sequence
     __host__ __device__ static constexpr auto Reverse();
 
-    __host__ __device__ static constexpr index_t Front()
+    __host__ __device__ static constexpr auto Front()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[0];
+        static_assert(mSize > 0, "wrong!");
+
+        return Get(Number<0>{});
     }
 
-    __host__ __device__ static constexpr index_t Back()
+    __host__ __device__ static constexpr auto Back()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[mSize - 1];
+        static_assert(mSize > 0, "wrong!");
+
+        return Get(Number<mSize - 1>{});
     }
 
     __host__ __device__ static constexpr auto PopFront();
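Front/Back now route through Get, so an out-of-range read on an empty Sequence fails at compile time rather than silently returning the dummy tail element. In terms of the Seq_/Number_ sketch after the previous hunk:

// reusing Seq_ and Number_ from the sketch above:
static_assert(decltype(Seq_<5, 7>::Get(Number_<0>{}))::value == 5, "Front() == Get(Number<0>{})");
static_assert(decltype(Seq_<5, 7>::Get(Number_<1>{}))::value == 7, "Back()  == Get(Number<mSize - 1>{})");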
composable_kernel/include/utility/integral_constant.hpp
@@ -13,30 +13,64 @@ struct integral_constant
     __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
 };
 
-template <class T, T X, T Y>
-__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
+template <class X, class Y>
+struct is_same : public integral_constant<bool, false>
 {
-    return integral_constant<T, X + Y>{};
-}
+};
 
-template <class T, T X, T Y>
-__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
+template <class X>
+struct is_same<X, X> : public integral_constant<bool, true>
 {
-    return integral_constant<T, X * Y>{};
-}
+};
 
 template <index_t N>
 using Number = integral_constant<index_t, N>;
 
-template <class X, class Y>
-struct is_same : public integral_constant<bool, false>
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
 {
+    return Number<X + Y>{};
 }
 
-template <class X>
-struct is_same<X, X> : public integral_constant<bool, true>
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
 {
+    static_assert(Y <= X, "wrong!");
+    return Number<X - Y>{};
 }
 
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
+{
+    return Number<X * Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return Number<X / Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return Number<X % Y>{};
+}
+
+#if 0
+static constexpr Number<0> 0_c;
+static constexpr Number<1> 1_c;
+static constexpr Number<2> 2_c;
+static constexpr Number<3> 3_c;
+static constexpr Number<4> 4_c;
+static constexpr Number<5> 5_c;
+static constexpr Number<6> 6_c;
+static constexpr Number<7> 7_c;
+static constexpr Number<8> 8_c;
+static constexpr Number<9> 9_c;
+#endif
+
 } // namespace ck
 #endif
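Usage sketch for the new operators: compile-time index arithmetic now reads like ordinary arithmetic while staying in the type system. A standalone mini-version using std::integral_constant, which is shaped like ck's (the block-count names are hypothetical):

#include <type_traits>

template <int N>
using Num = std::integral_constant<int, N>;

template <int X, int Y>
constexpr auto operator/(Num<X>, Num<Y>)
{
    static_assert(Y > 0, "wrong!");
    return Num<X / Y>{};
}

// e.g. number of E-blocks from E and EPerBlock, computed entirely in types
static_assert(decltype(Num<6>{} / Num<2>{})::value == 3, "6 / 2 == 3 at compile time");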
composable_kernel/include/utility/math.hpp
@@ -42,20 +42,16 @@ struct integer_divide_ceiler
     }
 };
 
-template <class T>
-__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
+template <class X, class Y>
+__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
 {
-    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
-
-    return (a + b - 1) / b;
+    return (x + y - 1) / y;
 }
 
-template <class T>
-__host__ __device__ constexpr T integer_least_multiple(T a, T b)
+template <class X, class Y>
+__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
 {
-    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
-
-    return b * integer_divide_ceil(a, b);
+    return y * integer_divide_ceil(x, y);
 }
 
 template <class T>
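Relaxing the single type T to independent X and Y lets callers mix, say, a plain int extent with a Number<> block size (whose operators were just generalized above). The arithmetic itself is unchanged; a plain-int sketch:

constexpr int integer_divide_ceil_demo(int x, int y) { return (x + y - 1) / y; }
constexpr int integer_least_multiple_demo(int x, int y) { return y * integer_divide_ceil_demo(x, y); }

static_assert(integer_divide_ceil_demo(7, 4) == 2, "ceil(7/4)");
static_assert(integer_least_multiple_demo(7, 4) == 8, "smallest multiple of 4 that is >= 7");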