gaoqiong / composable_kernel_ROCM

Commit 2fd6c6d4, authored Jan 31, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: c32d3448, 6651a124

Changes: 78 files. Showing 20 changed files with 571 additions and 352 deletions (+571 -352).
include/ck/wrapper/utils/layout_utils.hpp   +58 -23
include/ck/wrapper/utils/tensor_partition.hpp   +143 -233
include/ck/wrapper/utils/tensor_utils.hpp   +37 -74
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp   +2 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp   +3 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp   +2 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp   +2 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp   +3 -4
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp   +4 -1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp   +1 -2
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp   +1 -2
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp   +64 -0
library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp   +83 -0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp   +3 -1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp   +3 -1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp   +3 -1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp   +3 -1
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp   +73 -8
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp   +82 -0
library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp   +1 -1
include/ck/wrapper/utils/layout_utils.hpp
...
...
@@ -22,14 +22,19 @@ namespace wrapper {
 // Disable from doxygen docs generation
 /// @cond
 // forward declaration
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout;
 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());

 namespace {
-// Generate packed (column-major) strides if not passed
+/**
+ * \brief Generate packed (column-major) strides if not passed
+ *
+ * \param shape Tensor shape.
+ * \return Generated column-major strides.
+ */
 template <typename... Ts>
 __host__ __device__ constexpr static auto
 GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
...
...
@@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
         Number<decltype(unrolled_shape)::Size()>{});
 }

+/**
+ * \brief Create naive tensor descriptor from nested shape.
+ *
+ * \param shape Tensor shape.
+ * \param strides Tensor strides.
+ * \return Unrolled descriptor
+ */
 template <typename LayoutShape, typename LayoutStrides>
-__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
-                                                         const LayoutStrides& strides)
+__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
+                                                          const LayoutStrides& strides)
 {
     const auto unrolled_shape = UnrollNestedTuple(shape);
     if constexpr(is_same_v<LayoutStrides, Tuple<>>)
...
...
@@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap
 template <typename Shape, typename Strides>
 __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
 {
-    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
-    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
+    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
+    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
 }
/**
...
...
@@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
 template <typename Shape>
 __host__ __device__ constexpr auto make_layout(const Shape& shape)
 {
-    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
-    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
+    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
+    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
 }

 // Layout helpers
 // get
-// Get dim (could be returned from get with empty Idxs)
+/**
+ * \private
+ * \brief Get dim.
+ *
+ * \param dim Dimension.
+ * \return Returned the same dimension.
+ */
 template <typename T>
 __host__ __device__ T constexpr get(const T& dim)
...
...
@@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
         },
         Number<old_shape_dims>{});

-    const auto& flatten_desc = layout.GetUnnestedDescriptor();
+    const auto& flatten_desc = layout.GetUnrolledDescriptor();
     auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
     return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
 }
...
...
@@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
 }

 // size
-// Get dim size (could be returned from get function)
+/**
+ * \private
+ * \brief Get size.
+ *
+ * \param dim Size.
+ * \return Returned the same size.
+ */
 template <typename T>
 __host__ __device__ T constexpr size(const T& dim)
...
...
@@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
  * \param layout Layout to get Shape of.
  * \return Requsted length.
  */
-template <index_t idx, typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <index_t idx, typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.template GetLength<idx>();
 }
...
...
@@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
  * \param layout Layout to calculate shape size.
  * \return Requsted size.
  */
-template <typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.GetLengths();
 }
...
...
@@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
  * \param layout Layout to calculate rank.
  * \return Requsted rank.
  */
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 __host__ __device__ constexpr auto
-rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
+rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return Shape::Size();
 }
...
...
@@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
 /**
  * \private
  * \brief Rank for scalar
+ *
+ * \param dim Dimension scalar.
+ * \return Returned 1.
  */
 template <index_t IDim>
-__host__ __device__ constexpr index_t rank(const Number<IDim>&)
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
 {
     return 1;
 }

 /**
  * \private
  * \brief Rank for scalar
+ *
+ * \param dim Dimension scalar.
+ * \return Returned 1.
  */
-__host__ __device__ constexpr index_t rank(const index_t&)
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim)
 {
     return 1;
 }
/**
* \brief Hierarchical rank.
...
...
@@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
  * \param layout Layout to calculate depth.
  * \return Requsted depth.
  */
-template <typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     const auto& shape = layout.GetShape();
     return TupleDepth(shape);
...
...
@@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
 /**
  * \private
  * \brief Depth for scalar
+ *
+ * \param dim Scalar.
+ * \return Returned 0.
  */
 template <index_t IDim>
-__host__ __device__ constexpr index_t depth(const Number<IDim>&)
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
 {
     return 0;
 }

 /**
  * \private
  * \brief Depth for scalar
+ *
+ * \param dim Scalar.
+ * \return Returned 0.
  */
-__host__ __device__ constexpr index_t depth(const index_t&)
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim)
 {
     return 0;
 }
/**
* \brief Hierarchical depth.
...
...
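For orientation, a minimal host-side sketch of how the renamed layout API reads after this change. Only make_layout, Layout, and GetUnrolledDescriptor come from the diff above; the shape values and the function name are hypothetical.

// Sketch only: assumes the ck::wrapper headers from this commit are on the include path.
#include "ck/wrapper/utils/layout_utils.hpp"

void layout_rename_example()
{
    using ck::Number;
    using ck::make_tuple;

    // Nested (2, (2, 4)) shape; packed column-major strides are generated when omitted.
    const auto shape  = make_tuple(Number<2>{}, make_tuple(Number<2>{}, Number<4>{}));
    const auto layout = ck::wrapper::make_layout(shape);

    // Layout's second template parameter is now UnrolledDescriptorType and the flattened
    // descriptor accessor is GetUnrolledDescriptor() (previously GetUnnestedDescriptor()).
    const auto& unrolled_desc = layout.GetUnrolledDescriptor();
    (void)unrolled_desc;
}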
include/ck/wrapper/utils/tensor_partition.hpp
...
...
@@ -6,12 +6,22 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
namespace
ck
{
namespace
wrapper
{
namespace
{
// Calculate shape for partition based on number of threads per each dim and
// previous shape
/**
* \brief Calculate shape for partition based on number of threads per each dim and
* previous shape
*
* \param shape Base tensor shape.
* \param thread_lengths Tuple of thread lengths.
* \return Partition shape.
*/
template
<
typename
...
Ts
,
typename
...
Ls
>
__host__
__device__
constexpr
auto
CalculateLocalPartitionShape
(
const
Tuple
<
Ts
...
>&
shape
,
const
Tuple
<
Ls
...
>&
thread_lengths
)
...
...
@@ -20,265 +30,165 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
     return generate_tuple(
         [&](auto i) {
             constexpr auto num_i = Number<i>{};
-            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-            {
-                // if tuple then recurrence
-                return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
-            }
-            else
-            {
-                const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
-                return slice_len;
-            }
-        },
-        Number<Tuple<Ts...>::Size()>{});
-}
-
-// Calculate shape for partition based on number of threads per each dim,
-// previous strides and steps
-template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
-__host__ __device__ constexpr auto
-CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
-                                  const Tuple<Ls...>& thread_lengths,
-                                  const Tuple<Steps...>& steps,
-                                  const FlattenDescType& flatten_desc)
-{
-    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
-    const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
-    const auto unrolled_shape          = UnrollNestedTuple(shape);
-    constexpr auto dims                = decltype(unrolled_thread_lengths)::Size();
-    using UnrolledStepsType            = decltype(UnrollNestedTuple(steps));
-    using I1                           = Number<1>;
-
-    const auto transforms = generate_tuple(
-        [&](auto i) {
-            constexpr auto num_i = Number<i>{};
-            if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
-            {
-                // By default raked partition
-                const auto partition_stride = unrolled_thread_lengths.At(num_i);
-                return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
-                                            make_tuple(partition_stride));
-            }
-            else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
-            {
-                // Compiletime partition
-                if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
-                {
-                    // raked
-                    const auto partition_stride = unrolled_thread_lengths.At(num_i);
-                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
-                                                make_tuple(partition_stride));
-                }
-                else
-                {
-                    // packed
-                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
-                                                make_tuple(I1{}));
-                }
-            }
-            else
-            {
-                // Runtime partition
-                if(steps.At(num_i) == 1)
-                {
-                    // raked
-                    const auto partition_stride = unrolled_thread_lengths.At(num_i);
-                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
-                                                make_tuple(partition_stride));
-                }
-                else
-                {
-                    // packed
-                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
-                                                make_tuple(I1{}));
-                }
-            }
-        },
-        Number<dims>{});
-
-    const auto lower_dims =
-        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
-    const auto upper_dims =
-        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
-
-    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
-}
-
-template <typename... Ls, typename... Steps>
-__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
-                                                                const Tuple<Steps...>& steps,
-                                                                index_t& thread_id)
-{
-    return generate_tuple(
-        [&](auto i) {
-            constexpr auto num_i = Number<i>{};
-            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
-            {
-                // if tuple then recurrence
-                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
-                {
-                    return CalculateLayoutOffsetIdxImpl(
-                        thread_lengths.At(num_i), Tuple<>{}, thread_id);
-                }
-                else
-                {
-                    return CalculateLayoutOffsetIdxImpl(
-                        thread_lengths.At(num_i), steps.At(num_i), thread_id);
-                }
-            }
-            else
-            {
-                // Update thread_id after each dim
-                const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
-                thread_id /= thread_lengths.At(num_i);
-                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
-                {
-                    return dim_thread_id;
-                }
-                else
-                {
-                    // Apply step
-                    return steps.At(num_i) * dim_thread_id;
-                }
-            }
-        },
-        Number<Tuple<Ls...>::Size()>{});
-}
+            const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
+            return slice_len;
+        },
+        Number<Tuple<Ls...>::Size()>{});
+}

-// Convert integer thread_idx to tuple index with steps applied
-template <typename... Ls, typename... Steps>
-__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
-                                                            const Tuple<Steps...>& steps,
-                                                            const index_t thread_id)
-{
-    // Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
-    index_t thread_id_copy = thread_id;
-    return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
-}
+/**
+ * \brief Calculate total number of blocks.
+ *
+ * \param shape Base tensor shape.
+ * \param tile_shape Tile shape.
+ * \return Tuple with blocks number.
+ */
+template <typename... Ts, typename... Ls>
+__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
+                                                     const Tuple<Ls...>& tile_shape)
+{
+    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
+    return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
+                          Number<Tuple<Ls...>::Size()>{});
+}

-// Apply steps to index represented as tuple
-template <typename... Steps, typename... Idxs>
-__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
-                                                            const Tuple<Idxs...>& block_idxs)
-{
-    return generate_tuple(
-        [&](auto i) {
-            constexpr auto num_i = Number<i>{};
-            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
-            {
-                // if tuple then recurrence
-                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
-                {
-                    return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
-                }
-                else
-                {
-                    return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
-                }
-            }
-            else
-            {
-                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
-                {
-                    return block_idxs.At(num_i);
-                }
-                else
-                {
-                    // apply step
-                    return steps.At(num_i) * block_idxs.At(num_i);
-                }
-            }
-        },
-        Number<Tuple<Idxs...>::Size()>{});
-}
+/**
+ * \brief Calculate scaled offset for new partition/tile.
+ *
+ * \param thread_idxs Thread 1d id.
+ * \param partition_lengths_seq Sequence of partition shape.
+ * \param old_offset_idxs Multi index offset from base tensor to shift values.
+ * \return Partition shape.
+ */
+template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
+__host__ __device__ constexpr auto
+CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
+                         const PartitionLengthsSeq& partition_lengths_seq,
+                         const OldOffsetIdxs& old_offset_idxs)
+{
+    return thread_idxs * partition_lengths_seq + old_offset_idxs;
+}

-// User passes only shape per block to the make_local_tile function. This function calculates
-// block layout based on the shape.
-template <typename... Ts, typename... BlockDims>
-__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
-                                                         const Tuple<BlockDims...>& tile_shape)
-{
-    return generate_tuple(
-        [&](auto i) {
-            constexpr auto num_i = Number<i>{};
-            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-            {
-                // if tuple then recurrence
-                return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
-            }
-            else
-            {
-                return shape.At(num_i) / tile_shape.At(num_i);
-            }
-        },
-        Number<Tuple<Ts...>::Size()>{});
-}
 } // namespace

 /**
- * \brief Create local partition for thread.
+ * \brief Create local partition for thread (At now only packed partition
+ * is supported).
  *
  * \param tensor Tensor for partition.
- * \param thread_lengths Layout of threads.
+ * \param thread_lengths Layout of threads (could not be nested).
  * \param thread_id Thread index represented as integer.
- * \param steps Thread step (default=1, raked partition)
  * \return Partition tensor.
  */
-template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
-__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
-                                                        const ThreadLengthsTuple& thread_lengths,
-                                                        const index_t thread_id,
-                                                        const StepsTuple steps = StepsTuple{})
+template <typename TensorType, typename ThreadLengthsTuple>
+__host__ __device__ constexpr auto
+make_local_partition(TensorType& tensor,
+                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
+                     const index_t thread_id)
 {
-    // Create shape, strides and layout for new partition tensor
-    const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
-    // Create new descriptor and layout
-    const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
-    auto partition_desc =
-        CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
-    const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
-        partition_shape, partition_desc);
-    // Calculate offset for new partition tensor
-    const auto offset_idx       = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
-    const auto partition_offset = layout(tensor)(offset_idx);
-    return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
-                                                             partition_layout);
+    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
+    // Calculate new partition shape
+    const auto& tensor_shape = shape(tensor);
+    constexpr auto partition_shape =
+        CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
+    // Create Thread Cluster Descriptor
+    constexpr auto partition_lengths_seq = generate_sequence_v2(
+        [&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
+    constexpr auto thread_lengths_seq = generate_sequence_v2(
+        [&](auto I) { return size<I>(ThreadLengthsTuple{}); }, Number<ThreadLengthsTuple::Size()>{});
+    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
+    // Calculate thread idxs and offsets
+    const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
+    const auto offset_multi_idxs =
+        CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
+    // Create new layout and tensor
+    auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
+    const auto partition_layout =
+        Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
+            partition_shape, flatten_desc);
+    auto partition_tensor =
+        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
+    // Apply offsets
+    partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
+    return partition_tensor;
 }

 /**
- * \brief Create local tile for thread block.
+ * \brief Create local tile for thread block. (At now only packed tile
+ * is supported).
+ *
+ * \note Temporary to gain the best performance use 2d tile_shape.
  *
  * \param tensor Tensor for partition.
  * \param tile_shape Shapes of requested tile.
- * \param block_idxs Block index represented as tuple.
- * \param steps Block step (default=1, raked partition)
+ * \param block_id Block index represented as integer.
  * \return Tile tensor.
  */
-template <typename TensorType,
-          typename BlockShapeTuple,
-          typename BlockIdxTuple,
-          typename StepsTuple = Tuple<>>
-__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
-                                                   const BlockShapeTuple& tile_shape,
-                                                   const BlockIdxTuple& block_idx,
-                                                   const StepsTuple steps = StepsTuple{})
+template <typename TensorType, typename BlockShapeTuple>
+__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
+                                                   const BlockShapeTuple& tile_shape,
+                                                   const index_t block_id)
 {
-    // Create block lengths, strides and layout for new tile tensor
-    const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
-    // Create new descriptor and layout
-    const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
-    auto tile_desc = CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
-    const auto tile_layout =
-        Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(tile_shape, tile_desc);
-    // Calculate offset for new partition tensor
-    const auto offset_idx  = CalculateLayoutOffsetIdx(steps, block_idx);
-    const auto tile_offset = layout(tensor)(offset_idx);
-    return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
-                                                             tile_layout);
+    static_assert(!IsNestedTuple(BlockShapeTuple{}));
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+
+    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
+
+    if constexpr(BlockShapeTuple::Size() == I2)
+    {
+        // Optimized version for 2d tile shape [MxK]
+        const auto block_2_tile_map =
+            BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
+                                            BlockShapeTuple{}.At(I1),
+                                            remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
+        const auto block_work_idx =
+            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
+        const index_t m_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
+        const index_t k_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
+        const auto offset_multi_idxs =
+            make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
+        // Create new layout and tensor
+        const auto tile_layout =
+            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
+                                                                                     aligned_desc);
+        auto tile_tensor =
+            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
+        // Apply offsets
+        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
+        return tile_tensor;
+    }
+    else
+    {
+        // Calculate offsets
+        // Sequence with data to process per block
+        constexpr auto tile_shape_seq =
+            generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
+                                 Number<BlockShapeTuple::Size()>{});
+        // Tuple with number of blocks
+        const auto block_lengths           = CalculateGridSize(shape(tensor), tile_shape);
+        constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
+        const auto block_idxs =
+            block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
+        const auto offset_multi_idxs =
+            CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
+        // Create new layout and tensor
+        const auto tile_layout =
+            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
+                                                                                     aligned_desc);
+        auto tile_tensor =
+            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
+        // Apply offsets
+        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
+        return tile_tensor;
+    }
 }

 } // namespace wrapper
...
...
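A hedged sketch of how the reworked partition helpers are called after this change: block_id and thread_id are now plain integers and the Steps argument is gone. The tile and thread-layout sizes, the tensor parameter, and the function name are hypothetical; only make_local_tile and make_local_partition are taken from the diff.

// Sketch only: assumes a ck::wrapper global-memory tensor created elsewhere.
#include "ck/wrapper/utils/tensor_partition.hpp"

template <typename GlobalTensor>
__device__ void partition_sketch(GlobalTensor& global_tensor)
{
    using ck::Number;
    using ck::make_tuple;

    // 2D tile per workgroup; a rank-2 tile_shape takes the optimized BlockToCTileMap path above.
    const auto tile_shape = make_tuple(Number<128>{}, Number<64>{});
    auto tile = ck::wrapper::make_local_tile(global_tensor, tile_shape, blockIdx.x);

    // Per-thread slice of the tile; thread_lengths must not be a nested tuple.
    const auto thread_lengths = make_tuple(Number<8>{}, Number<32>{});
    auto partition = ck::wrapper::make_local_partition(tile, thread_lengths, threadIdx.x);
    (void)partition;
}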
include/ck/wrapper/utils/tensor_utils.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -10,6 +10,7 @@
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
namespace
ck
{
namespace
wrapper
{
...
...
@@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
 // Disable from doxygen docs generation
 /// @cond
 // forward declarations
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout;
 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,     // params for Register memory
-          index_t ScalarPerVector // param for Register memory
-          >
+          typename UnrolledDescriptorType>
 struct Tensor;
 template <typename FromType, typename ToType>
...
...
@@ -45,13 +42,22 @@ struct Slice
     __host__ __device__ constexpr Slice() : from_(), to_() {}
     __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}

+    /**
+     * \brief Calculate slice range.
+     *
+     * \param dim Dimension size.
+     * \return Slice range.
+     */
     template <typename T>
     __host__ __device__ constexpr auto range(const T& dim) const
     {
         if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                      is_same_v<T, index_t>)
         {
-            assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
+            if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
+            {
+                throw std::runtime_error("Invalid range");
+            }
             if(to_ < 0)
             {
                 return dim - from_ + to_ + 1;
...
...
@@ -101,40 +107,27 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
 template <MemoryTypeEnum MemoryType,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType>
+          typename UnrolledDescriptorType>
 constexpr auto make_tensor(ElementType* pointer,
-                           const Layout<Shape, UnnestedDescriptorType>& layout)
+                           const Layout<Shape, UnrolledDescriptorType>& layout)
 {
-    return Tensor<MemoryType,
-                  ElementType,
-                  Shape,
-                  UnnestedDescriptorType,
-                  0 /*NumVectors*/,
-                  0 /*ScalarPerVector*/>(pointer, layout);
+    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
 }

 /**
  * \brief Make SGPR or VGPR tensor function.
  *
  * \tparam MemoryType Type of memory.
  * \tparam NumVectors Number of vectors.
  * \tparam ScalarPerVector Scalars per vector.
  * \tparam ElementType Memory data type.
  * \return Constructed tensor.
  */
 template <MemoryTypeEnum MemoryType,
-          index_t NumVectors,
-          index_t ScalarPerVector,
-          typename ElementType>
-constexpr auto make_register_tensor()
+          typename ElementType,
+          typename Shape,
+          typename UnrolledDescriptorType>
+constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
-    const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
-    return Tensor<MemoryType,
-                  ElementType,
-                  Tuple<Number<NumVectors>>,
-                  std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
-                  NumVectors,
-                  ScalarPerVector>(layout);
+    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
 }
/**
...
...
@@ -146,15 +139,9 @@ constexpr auto make_register_tensor()
 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
-                                                              ElementType,
-                                                              Shape,
-                                                              UnnestedDescriptorType,
-                                                              NumVectors,
-                                                              ScalarPerVector>& tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr const auto&
+layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return tensor.GetLayout();
 }
...
...
@@ -170,15 +157,9 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
-                                                     ElementType,
-                                                     Shape,
-                                                     UnnestedDescriptorType,
-                                                     NumVectors,
-                                                     ScalarPerVector>& tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return size<Idxs...>(tensor.GetLayout());
 }
...
...
@@ -194,15 +175,9 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
-                                                     ElementType,
-                                                     Shape,
-                                                     UnnestedDescriptorType,
-                                                     NumVectors,
-                                                     ScalarPerVector>& tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return rank<Idxs...>(tensor.GetLayout());
 }
...
...
@@ -218,15 +193,9 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
-                                                      ElementType,
-                                                      Shape,
-                                                      UnnestedDescriptorType,
-                                                      NumVectors,
-                                                      ScalarPerVector>& tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return depth<Idxs...>(tensor.GetLayout());
 }
...
...
@@ -240,15 +209,9 @@ __host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename UnnestedDescriptorType,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
-                                                             ElementType,
-                                                             Shape,
-                                                             UnnestedDescriptorType,
-                                                             NumVectors,
-                                                             ScalarPerVector>& tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr const auto&
+shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return shape(tensor.GetLayout());
 }
...
...
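A hedged sketch of the trimmed-down tensor constructors: Tensor now carries only MemoryType, ElementType, Shape, and UnrolledDescriptorType, and make_register_tensor takes an explicit layout instead of NumVectors/ScalarPerVector. The pointer, sizes, and function name below are hypothetical; the constructor shapes follow the diff above.

// Sketch only: assumes the ck::wrapper headers from this commit.
#include "ck/wrapper/utils/layout_utils.hpp"
#include "ck/wrapper/utils/tensor_utils.hpp"

__device__ void tensor_sketch(float* p_global)
{
    using ck::Number;
    using ck::make_tuple;
    namespace cw = ck::wrapper;

    // Global-memory view over p_global with a 16 x 8 packed layout.
    const auto global_layout = cw::make_layout(make_tuple(Number<16>{}, Number<8>{}));
    auto global_tensor = cw::make_tensor<cw::MemoryTypeEnum::Global>(p_global, global_layout);

    // Register (VGPR) tensor, now constructed from a layout rather than
    // NumVectors/ScalarPerVector template parameters.
    auto vgpr_tensor = cw::make_register_tensor<cw::MemoryTypeEnum::Vgpr, float>(
        cw::make_layout(make_tuple(Number<4>{})));

    (void)global_tensor;
    (void)vgpr_tensor;
}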
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...
...
@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...
...
@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error("Conv_bwd_data: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...
...
@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...
...
@@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...
...
@@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator
                 const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                 AccDataType v_acc = 0;
+                ComputeTypeA v_a  = 0;
+                ComputeTypeB v_b  = 0;

                 for(int k = 0; k < K; ++k)
                 {
-                    ComputeTypeA v_a;
-                    ComputeTypeB v_b;
-
                     // use PassThrough instead of ConvertBF16RTN for reference calculation
                     if constexpr(is_same_v<AElementwiseOperation,
                                            ck::tensor_operation::element_wise::ConvertBF16RTN>)
...
...
@@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                         ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }

-                CDataType v_c;
+                CDataType v_c = 0;

                 arg.c_element_op_(v_c, v_acc);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
...
...
@@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory<
         return op_ptrs;
     }
 };
+#endif

 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp
...
...
@@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
         return op_ptrs;
     }
 };
+#endif

 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                  XDataType,
                                                                  MeanInvStdDataType,
                                                                  DGammaDataType,
                                                                  DBetaDataType,
                                                                  5,
                                                                  3>>
{
    using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
                                                     XDataType,
                                                     MeanInvStdDataType,
                                                     DGammaDataType,
                                                     DBetaDataType,
                                                     5,
                                                     3>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
                     is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
                     is_same_v<DBetaDataType, F32>)
        {
            add_device_groupnorm_bwd_gamma_beta_f32_instances(op_ptrs);
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
// FP16
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&);
#endif

#ifdef CK_ENABLE_FP32
// FP32
void add_device_layernorm2d_bwd_gamma_beta_f32_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&);
#endif

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                  XDataType,
                                                                  MeanInvStdDataType,
                                                                  DGammaDataType,
                                                                  DBetaDataType,
                                                                  Rank,
                                                                  NumReduceDim>>
{
    using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
                                                     XDataType,
                                                     MeanInvStdDataType,
                                                     DGammaDataType,
                                                     DBetaDataType,
                                                     Rank,
                                                     NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<DYDataType, F16> && is_same_v<XDataType, F16> &&
                     is_same_v<MeanInvStdDataType, F16> && is_same_v<DGammaDataType, F16> &&
                     is_same_v<DBetaDataType, F16>)
        {
            if constexpr(Rank == 2 && NumReduceDim == 1)
            {
                add_device_layernorm2d_bwd_gamma_beta_f16_instances(op_ptrs);
            }
        }
#endif
#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
                     is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
                     is_same_v<DBetaDataType, F32>)
        {
            if constexpr(Rank == 2 && NumReduceDim == 1)
            {
                add_device_layernorm2d_bwd_gamma_beta_f32_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
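A hedged sketch of consuming the new factory specialization on the host side. The F16 alias in the header is assumed to be ck::half_t, the function name is hypothetical, and the rank/reduce-dim values follow the 2D layernorm branch above.

// Sketch only: enumerate the layernorm bwd gamma/beta instances registered above.
#include <cstddef>
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp"

std::size_t count_layernorm_bwd_gamma_beta_f16_instances()
{
    using ck::half_t;
    using Factory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<half_t, // DY
                                                                      half_t, // X
                                                                      half_t, // Mean/InvStd
                                                                      half_t, // DGamma
                                                                      half_t, // DBeta
                                                                      2,      // Rank
                                                                      1>>;    // NumReduceDim

    // Each element is a std::unique_ptr to a DeviceNormalizationBwdGammaBeta operation.
    auto op_ptrs = Factory::GetInstances();
    return op_ptrs.size();
}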
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
...
...
@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2<Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
...
...
@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2<Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
...
...
@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2<Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2<Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
...
...
@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding   = ck::tensor_operation::device::GemmSpecialization::KPadding;
 static constexpr auto GemmMNPadding  = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
...
...
@@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    // clang-format on
    >;

-template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
+template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
+          ck::PipelineVersion PipVer,
+          ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
    // clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
    // clang-format on
    >;
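Each combination of GemmSpecialization, PipelineVersion, and LoopScheduler therefore names a distinct tuple of kernel types. For illustration, the alias name below is hypothetical, while the template arguments are ones registered further down:

// Hypothetical alias, shown only to make the parameterization concrete.
using irregular_interwave_v1_mnk_padded =
    device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding,
                                                                    ck::PipelineVersion::v1,
                                                                    ck::LoopScheduler::Interwave>;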
...
...
@@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
}

} // namespace instance
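These calls all follow one pattern: add_device_operation_instances takes the caller's vector of device-operation pointers and a tuple of concrete kernel types, default-constructs one object per tuple element, and appends it. A minimal, self-contained sketch of that pattern, assuming this shape for the helper rather than quoting CK's actual implementation:

#include <memory>
#include <tuple>
#include <vector>

// Sketch: append one default-constructed instance per tuple element to the registry vector.
template <typename BaseOp, typename... Instances>
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& instances,
                                    const std::tuple<Instances...>&)
{
    (instances.push_back(std::make_unique<Instances>()), ...);
}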
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding   = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding  = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
...
...
@@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
    // clang-format on
    >;

template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
          ck::PipelineVersion PipVer,
          ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
    DeviceGemmXdlSplitKCShuffle<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
    // clang-format on
    >;
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
...
...
@@ -112,6 +148,52 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
}

} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp
...
...
@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {

void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
        instances)
{
...
...