gaoqiong / composable_kernel_ROCM / Commits / 874a78f9

Commit 874a78f9 authored Feb 09, 2024 by Jun Liu
Merge branch 'amd-develop' into amd-master

Parents: 6368be50, 2fd6c6d4
Changes: 89
Showing 20 changed files with 1075 additions and 593 deletions.
Changed files:

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp (+30 / -0)
include/ck/utility/data_type.hpp (+65 / -0)
include/ck/utility/is_known_at_compile_time.hpp (+7 / -1)
include/ck/wrapper/layout.hpp (+141 / -49)
include/ck/wrapper/operations/copy.hpp (+137 / -3)
include/ck/wrapper/tensor.hpp (+289 / -200)
include/ck/wrapper/utils/layout_utils.hpp (+58 / -23)
include/ck/wrapper/utils/tensor_partition.hpp (+143 / -233)
include/ck/wrapper/utils/tensor_utils.hpp (+37 / -74)
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp (+2 / -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp (+3 / -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp (+2 / -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp (+2 / -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp (+3 / -4)
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp (+4 / -1)
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp (+1 / -2)
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp (+1 / -2)
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp (+64 / -0)
library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp (+83 / -0)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp (+3 / -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+        {
+            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return transform_tensor_descriptor(
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
         else
         {
             return transform_tensor_descriptor(
...

@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+        {
+            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return transform_tensor_descriptor(
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
         else
         {
             return transform_tensor_descriptor(
...
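Note on the new KPadding branches: the unmerge transform splits the padded K axis into (KBatch, K0Padded, K1), which only works when KPad == KBatch * K0Padded * K1; the right-pad transform supplies the missing KPad - K elements. Below is a minimal sketch of the arithmetic that relation implies, assuming KPad is K rounded up to the next multiple of KBatch * K1 (the helper and struct names are illustrative, not from the commit):

#include <cstdint>

struct SplitKPadding
{
    std::int64_t K0Padded; // K0 length per K-batch after padding
    std::int64_t KPad;     // padded K length, always >= K
};

inline SplitKPadding compute_split_k_padding(std::int64_t K, std::int64_t K1, std::int64_t KBatch)
{
    // Round K up to the next multiple of KBatch * K1 (ceiling division),
    // so that KPad == KBatch * K0Padded * K1 as required by the unmerge transform.
    const std::int64_t k0_padded = (K + KBatch * K1 - 1) / (KBatch * K1);
    return {k0_padded, KBatch * k0_padded * K1};
}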
include/ck/utility/data_type.hpp

@@ -189,6 +189,7 @@ struct vector_type<T, 1>
     }
 };
 
+static int err = 0;
 template <typename T>
 struct vector_type<T, 2>
 {

@@ -221,6 +222,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 
     template <typename X>

@@ -236,6 +241,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };

The same four-line fallback (`else { return err; }`) is added in both accessor templates of every remaining vector_type specialization:

@@ -278,6 +287,10 @@ and @@ -298,6 +311,10 @@ struct vector_type<T, 4>    (returns data_.d4x1_ or err)
@@ -347,6 +364,10 @@ and @@ -372,6 +393,10 @@ struct vector_type<T, 8>    (returns data_.d8x1_ or err)
@@ -428,6 +453,10 @@ and @@ -458,6 +487,10 @@ struct vector_type<T, 16>   (returns data_.d16x1_ or err)
@@ -520,6 +553,10 @@ and @@ -554,6 +591,10 @@ struct vector_type<T, 32>   (returns data_.d32x1_ or err)
@@ -623,6 +664,10 @@ and @@ -662,6 +707,10 @@ struct vector_type<T, 64>   (returns data_.d64x1_ or err)
@@ -737,6 +786,10 @@ and @@ -780,6 +833,10 @@ struct vector_type<T, 128>  (returns data_.d128x1_ or err)
@@ -861,6 +918,10 @@ and @@ -908,6 +969,10 @@ struct vector_type<T, 256>  (returns data_.d256x1_ or err)
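The repeated change is easier to see in isolation: each accessor is an if-constexpr chain, and without a trailing else a non-matching instantiation would reach the end of a non-void function without a return statement. A simplified stand-in is sketched below; pair_type and AsType here are illustrative only, not the real vector_type:

#include <type_traits>

static int err = 0;

template <typename T>
struct pair_type
{
    T data_[2];

    template <typename X>
    auto& AsType()
    {
        if constexpr(std::is_same_v<X, T>)
        {
            return data_; // matching access path
        }
        else
        {
            return err; // dummy fallback so every instantiation returns something
        }
    }
};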
include/ck/utility/is_known_at_compile_time.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
     static constexpr bool value = false;
 };
 
+template <>
+struct is_known_at_compile_time<unsigned int>
+{
+    static constexpr bool value = false;
+};
+
 template <>
 struct is_known_at_compile_time<long_index_t>
 {
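A small compile-time check of what the new specialization states (both values are taken directly from the header, so this is usable as-is in a test):

#include "ck/utility/is_known_at_compile_time.hpp"

// unsigned int lengths, like index_t lengths, are treated as runtime-only values.
static_assert(!ck::is_known_at_compile_time<unsigned int>::value,
              "unsigned int lengths are runtime values");
static_assert(!ck::is_known_at_compile_time<ck::index_t>::value,
              "index_t lengths are runtime values");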
include/ck/wrapper/layout.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -14,22 +14,28 @@ namespace wrapper {
  * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
  * (dynamic layout). It is possible to pass nested shapes
  * (e.g. ((4, 2), 2)), nested dimensions are merged.
- * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
+ * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
  */
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout
 {
     private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
-    // Generate default idxs tuple (idx with all merged nested shapes)
+    /**
+     * \brief Generate default indices tuple (idx with all merged nested shapes)
+     *
+     * \param shape Shape to align.
+     * \return Multi idx tuple with zeros.
+     */
     template <typename... Ts>
-    __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
+    __host__ __device__ constexpr static auto
+    GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
     {
         return generate_tuple(
             [&](auto) {
-                if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
+                if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                 {
                     // runtime layout
                     return index_t(0);

@@ -43,11 +49,18 @@ struct Layout
             Number<Tuple<Ts...>::Size()>{});
     }
 
-    // Generate LowerDims in Compile-time for MergeTrasform using passed Type
-    // If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
-    // If tuple is element, then pass through (sequence with one element)
+    /**
+     * \brief Generate lower dims in compile-time for the Merge transform using
+     * provided type. If element of nested Tuple<Ts...> is also a tuple, then
+     * merge (generate sequence for merge). If tuple is element, then pass
+     * through (sequence with one element).
+     *
+     * \param shape Shape to align.
+     * \return LowerDims for MergeTrasform.
+     */
     template <typename Idx, typename... Ts>
-    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
+    __host__ __device__ constexpr static auto
+    GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
     {
         if constexpr(Idx::value == 0)
         {

@@ -87,11 +100,17 @@ struct Layout
         }
     }
 
-    // Iterate over nested tuples in shape
-    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
-    // Example idx: (1, 1), 1, 1
-    // Example shape: (2, (2, 2)), 2, (2, 2)
-    // Unrolled shape: 2, (2, 2), 2, (2, 2)
+    /**
+     * \brief Iterate over the nested tuples in the shape.
+     * Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
+     * Example idx: (1, 1), 1, 1
+     * Example shape: (2, (2, 2)), 2, (2, 2)
+     * Unrolled shape: 2, (2, 2), 2, (2, 2)
+     *
+     * \param shape Layout shape.
+     * \param idx Idx to align.
+     * \return Algined shape.
+     */
     template <typename... ShapeDims, typename... IdxDims>
     __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                               const Tuple<IdxDims...>& idx)

@@ -126,6 +145,13 @@ struct Layout
         }
     }
 
+    /**
+     * \brief Merge descriptor to 1D.
+     *
+     * \param shape Layout shape.
+     * \param desc Descriptor to merge.
+     * \return 1D descriptor.
+     */
     template <typename... ShapeDims, typename DescriptorToMerge>
     __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                           const DescriptorToMerge& desc)

@@ -137,18 +163,41 @@ struct Layout
         const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
         const auto upper_dims = make_tuple(Sequence<0>{});
         // Merge to 1d
-        return transform_tensor_descriptor(
-            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+        {
+            return transform_tensor_descriptor(
+                desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+        }
+        else
+        {
+            // If the descriptor is known at the compilation time,
+            // use `make_merge_transform_v1_carry_check` because it doesn't use
+            // memcpy.
+            return transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
+                lower_dims,
+                upper_dims);
+        }
     }
 
-    // Merge nested shape dims when corresponding index is also nested.
-    // Input desc shape: 2, 2, 2, 2, 2, 2
-    // Example idx: 1, 1, 1, 1
-    // Example shape: 2, (2, 2), 2, (2, 2)
-    // Merged shape: 2, 4, 2, 4
+    /**
+     * \brief Merge nested shape dims when corresponding index is also merged.
+     * Input desc shape: 2, 2, 2, 2, 2, 2
+     * Example idx: 1, 1, 1, (1, 1)
+     * Example shape: 2, (2, 2), 2, (2, 2)
+     * Merged shape: 2, 4, 2, 2, 2
+     *
+     * \param shape Layout shape.
+     * \param idxs Indexes to align descriptor.
+     * \param desc Descriptor to merge.
+     * \return Aligned descriptor to idx.
+     */
     template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
-    __host__ __device__ constexpr static auto CreateMergedDescriptor(
-        const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
+    __host__ __device__ constexpr static auto
+    CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
+                           [[maybe_unused]] const Tuple<IdxDims...>& idxs,
+                           DescriptorToMerge& desc)
     {
         const auto transforms = generate_tuple(
             [&](auto i) {

@@ -160,9 +209,19 @@ struct Layout
                 // If shape element is tuple and idx element is Number, then merge
                 // Unroll and reverse tuple to traverse column-major
                 const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
-                return make_merge_transform(merge_elems);
+                if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+                {
+                    return make_merge_transform(merge_elems);
+                }
+                else
+                {
+                    // If the descriptor is known at the compilation time,
+                    // use `make_merge_transform_v1_carry_check` because
+                    // it doesn't use memcpy.
+                    return make_merge_transform_v1_carry_check(merge_elems);
+                }
             }
             else
             {
                 // If shape element is integer and idx element is tuple, passed idx is wrong
                 static_assert(

@@ -185,14 +244,23 @@ struct Layout
     }
 
     using Descriptor1dType =
-        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
+        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
     using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
 
+    public:
+    /**
+     * \brief Transform descriptor to align to passed indexes.
+     *
+     * \param shape Layout shape.
+     * \param idxs Indexes to align descriptor.
+     * \param naive_descriptor Descriptor to merge.
+     * \return Aligned descriptor to idx.
+     */
     template <typename... ShapeDims, typename... IdxDims>
     __host__ __device__ constexpr static auto
     TransformDesc(const Tuple<ShapeDims...>& shape,
-                  const Tuple<IdxDims...>& idx,
-                  const UnnestedDescriptorType& naive_descriptor)
+                  const Tuple<IdxDims...>& idxs,
+                  const UnrolledDescriptorType& naive_descriptor)
     {
         if constexpr(Tuple<IdxDims...>::Size() == I1)
         {

@@ -208,19 +276,18 @@ struct Layout
             static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                           "Idx rank and Shape rank must be the same (except 1d).");
             // Unroll while IdxDims is nested
-            const auto aligned_shape = AlignShapeToIdx(shape, idx);
+            const auto aligned_shape = AlignShapeToIdx(shape, idxs);
             // Transform correct form of shape
-            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
+            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
         }
     }
 
     using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
-        Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
+        Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
 
-    public:
     __host__ __device__ constexpr auto GetElementSpaceSize() const
     {
-        return unnested_descriptor_.GetElementSpaceSize();
+        return unrolled_descriptor_.GetElementSpaceSize();
     }
 
     __host__ __device__ Layout() = delete;

@@ -232,16 +299,15 @@ struct Layout
      * \param unnested_descriptor Descriptor
      */
     __host__ __device__ constexpr Layout(const Shape& shape,
-                                         const UnnestedDescriptorType& unnested_descriptor)
-        : shape_(shape)
+                                         const UnrolledDescriptorType& unnested_descriptor)
+        : unrolled_descriptor_(unnested_descriptor), shape_(shape)
     {
         // Construct if runtime mode
-        if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
+        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
         {
-            unnested_descriptor_ = unnested_descriptor;
-            descriptor_1d_       = MakeMerge1d(shape_, unnested_descriptor_);
+            descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
             merged_nests_descriptor_ =
-                TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
+                TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
         }
     }

@@ -254,9 +320,9 @@ struct Layout
     template <typename Idxs>
     __host__ __device__ constexpr index_t operator()() const
     {
-        static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
+        static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
                       "Compiletime operator used on runtime layout.");
-        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
+        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
         using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
         return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
     }

@@ -283,7 +349,7 @@ struct Layout
         else
         {
             // Custom index, need to transform descriptor
-            const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
+            const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
             return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
         }
     }

@@ -350,29 +416,55 @@ struct Layout
     }
 
     /**
-     * \brief Get default descriptor (with the same size as Shape)
+     * \brief Get descriptor with all nested dimensions merged.
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (4, 2)
      *
-     * \return Default descriptor.
+     * \note The size of merged descriptor is the same as Layout's shape.
+     *
+     * \return Merged nests descriptor.
      */
-    __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
+    __host__ __device__ constexpr const MergedNestsDescriptorType& GetMergedNestingDescriptor() const
     {
         return merged_nests_descriptor_;
     }
 
+    /**
+     * \brief Get descriptor with all dimensions are merged (1D).
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (8)
+     *
+     * \return 1D descriptor.
+     */
+    __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
+    {
+        return descriptor_1d_;
+    }
+
     /**
      * \brief Get unnested descriptor (with unrolled dims)
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (2, 2, 2)
      *
-     * \return Flatten descriptor.
+     * \return Flattened descriptor.
      */
-    __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
+    __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
     {
-        return unnested_descriptor_;
+        return unrolled_descriptor_;
     }
 
     private:
-    UnnestedDescriptorType unnested_descriptor_;
+    // All dimensions are unrolled
+    UnrolledDescriptorType unrolled_descriptor_;
+    // 1D descriptor
     Descriptor1dType descriptor_1d_;
+    // All nesting are merged
     MergedNestsDescriptorType merged_nests_descriptor_;
+    // Example, shape: ((2, 2), 2)
+    // UnrolledDescriptorType lengths: (2, 2, 2)
+    // Descriptor1dType lengths: (8)
+    // MergedNestsDescriptorType lengths: (4, 2)
     const Shape shape_;
 };
...
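Minimal usage sketch for the renamed Layout API (illustrative, not taken from the commit; it assumes make_naive_tensor_descriptor_packed and the other CK helpers are visible through the wrapper headers). It builds a compile-time Layout over the nested shape ((2, 2), 2) and touches the three descriptor views named above, whose lengths per the doc comments are (2, 2, 2), (8) and (4, 2):

#include "ck/wrapper/layout.hpp"

using ck::Number;

__host__ __device__ void layout_views_example()
{
    // Nested shape ((2, 2), 2); the unrolled descriptor covers the flat dims (2, 2, 2).
    const auto shape =
        ck::make_tuple(ck::make_tuple(Number<2>{}, Number<2>{}), Number<2>{});
    const auto unrolled_desc = ck::make_naive_tensor_descriptor_packed(
        ck::make_tuple(Number<2>{}, Number<2>{}, Number<2>{}));

    const ck::wrapper::Layout<decltype(shape), decltype(unrolled_desc)> layout(shape,
                                                                               unrolled_desc);

    // GetUnrolledDescriptor()      -> lengths (2, 2, 2)
    // Get1DDescriptor()            -> lengths (8)
    // GetMergedNestingDescriptor() -> lengths (4, 2)
    (void)layout.GetUnrolledDescriptor();
    (void)layout.Get1DDescriptor();
    (void)layout.GetMergedNestingDescriptor();
}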
include/ck/wrapper/operations/copy.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
 #include "../utils/tensor_utils.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 namespace ck {
 namespace wrapper {
 
 /**
- * \brief Perform generic copy between two tensors. Tensors must have the
- * same size.
+ * \brief Perform generic copy between two tensors partitions (threadwise copy).
+ * Tensors must have the same size.
  *
  * \param src_tensor Source tensor.
  * \param dst_tensor Destination tensor.

The rest of the hunk (@@ -37,5 +42,134 @@) appends a new, fully added overload after the generic copy:

/**
 * \brief Perform optimized copy between two tensors partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalar per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType          = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims  = SrcShapeType::Size();

    constexpr auto thread_slice_lengths = generate_sequence_v2(
        [](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<false>,
            Sequence<false>>{in_grid_desc,
                             make_tuple(src_tensor.GetMultiIdxOffsets()),
                             out_grid_desc,
                             make_tuple(dst_tensor.GetMultiIdxOffsets()),
                             tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v1r3<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            tensor_operation::element_wise::PassThrough,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            InMemoryDataOperationEnum::Set,
            I1,
            true>{out_grid_desc,
                  dst_tensor.GetMultiIdxOffsets(),
                  tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto src_dst_slice_origin =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
            [&](auto I) {
                if constexpr(I == VectorDim)
                {
                    return Number<ScalarPerVector>{};
                }
                else
                {
                    return I1;
                }
            },
            Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v4r1<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            decltype(src_vector_tensor_lengths),
            decltype(dim_access_order)>{src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_dst_slice_origin,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     src_dst_slice_origin,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        copy(src_tensor, dst_tensor);
    }
}

} // namespace wrapper
} // namespace ck
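Usage sketch for the new overload (the tile and partition types are hypothetical wrapper tensors produced elsewhere, for example by the partition utilities also touched in this commit; only the call shape of the overload above is taken from the diff):

#include "ck/wrapper/operations/copy.hpp"

template <typename SrcTile, typename DstPartition>
__device__ void copy_tile_vectorized(const SrcTile& src_thread_tile, DstPartition& dst_partition)
{
    // Traverse dim 0 then dim 1, and vectorize dim 1 with 8 scalars per read/write.
    using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    ck::wrapper::copy<DimAccessOrder, /*VectorDim=*/1, /*ScalarPerVector=*/8>(src_thread_tile,
                                                                              dst_partition);
}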
include/ck/wrapper/tensor.hpp (+289 / -200, diff collapsed in this view)
include/ck/wrapper/utils/layout_utils.hpp (+58 / -23, diff collapsed in this view)
include/ck/wrapper/utils/tensor_partition.hpp (+143 / -233, diff collapsed in this view)
include/ck/wrapper/utils/tensor_utils.hpp (+37 / -74, diff collapsed in this view)
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp

@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
             return 0;
         }
 
+        throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
+        return 1;
     }
 
     float Run(const device::BaseArgument* p_arg,
...
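The same guard pattern is repeated in the three convolution reference operators below and in reference_image_to_column.hpp. A simplified sketch of the pattern (illustrative only, not the operator code): the throw turns an unsupported spatial rank into a hard error, and the unreachable return keeps compilers from warning about a non-void path without a return.

#include <stdexcept>

static int check_num_dim_spatial(int num_dim_spatial)
{
    if(num_dim_spatial >= 1 && num_dim_spatial <= 3)
    {
        return 0; // supported rank
    }
    throw std::runtime_error("number of dimensions must be between 1 and 3.");
    return 1; // unreachable; silences missing-return warnings
}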
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp

@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
             return 0;
         }
 
+        throw std::runtime_error(
+            "Conv_bwd_data: number of dimensions must be between 1 and 3.");
+        return 1;
     }
 
     float Run(const device::BaseArgument* p_arg,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
             return 0;
         }
 
+        throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }
 
     float Run(const device::BaseArgument* p_arg,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp

@@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator
             return 0;
         }
 
+        throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }
 
     float Run(const device::BaseArgument* p_arg,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp

@@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];
 
             AccDataType v_acc = 0;
-            ComputeTypeA v_a  = 0;
-            ComputeTypeB v_b  = 0;
 
             for(int k = 0; k < K; ++k)
             {
+                ComputeTypeA v_a;
+                ComputeTypeB v_b;
                 // use PassThrough instead of ConvertBF16RTN for reference calculation
                 if constexpr(is_same_v<AElementwiseOperation,
                                        ck::tensor_operation::element_wise::ConvertBF16RTN>)

@@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                     ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
-            CDataType v_c;
+            CDataType v_c = 0;
 
             arg.c_element_op_(v_c, v_acc);
...
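Reduced sketch of what the reference_gemm change amounts to (a simplified scalar loop, not the full operator): the per-iteration conversion temporaries now live inside the k-loop, and the output value is zero-initialized before the element-wise op writes it.

#include <cstddef>

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
CDataType reference_dot(const ADataType* a_row, const BDataType* b_col, std::size_t K)
{
    AccDataType v_acc = 0;
    for(std::size_t k = 0; k < K; ++k)
    {
        ADataType v_a = a_row[k]; // declared per iteration, as in the diff
        BDataType v_b = b_col[k];
        v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
    }
    CDataType v_c = 0; // zero-initialized before the final conversion
    v_c           = static_cast<CDataType>(v_acc);
    return v_c;
}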
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -10,6 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/numeric.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
             return 0;
         }
 
+        throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3.");
+        return 1;
     }
 
     float Run(const device::BaseArgument* p_arg,
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp

@@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory<
         return op_ptrs;
     }
 };
+#endif
 
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp

@@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
         return op_ptrs;
     }
 };
+#endif
 
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
    std::vector<
        std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                  XDataType,
                                                                  MeanInvStdDataType,
                                                                  DGammaDataType,
                                                                  DBetaDataType,
                                                                  5,
                                                                  3>>
{
    using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
                                                     XDataType,
                                                     MeanInvStdDataType,
                                                     DGammaDataType,
                                                     DBetaDataType,
                                                     5,
                                                     3>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
                     is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
                     is_same_v<DBetaDataType, F32>)
        {
            add_device_groupnorm_bwd_gamma_beta_f32_instances(op_ptrs);
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
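Host-side usage sketch for the new factory header (illustrative, not from the commit; the trailing 5 and 3 template arguments simply match the integer parameters fixed by the header above). It queries every registered FP32 groupnorm backward gamma/beta instance:

#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp"

#include <cstdio>

int main()
{
    using F32 = float;
    using ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta;
    using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;

    using DeviceOp = DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>;

    const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::printf("found %zu groupnorm bwd gamma/beta instances\n", op_ptrs.size());
    return 0;
}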
library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp (new file, mode 100644; diff collapsed in this view)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp

@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
         DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
         // pipeline v1, 2 waves
         ,
...