Commit 2fd6c6d4 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents c32d3448 6651a124
......@@ -22,14 +22,19 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
// Generate packed (column-major) strides if not passed
/**
* \brief Generate packed (column-major) strides if not passed
*
* \param shape Tensor shape.
* \return Generated column-major strides.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
......@@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
Number<decltype(unrolled_shape)::Size()>{});
}
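For intuition, the packed column-major rule sets the first stride to 1 and each later stride to the running product of the preceding extents. A minimal host-side sketch, using std::array in place of ck::Tuple (illustrative only, not the wrapper implementation):

#include <array>
#include <cstddef>

// Hedged sketch: packed column-major strides for an already-unrolled shape.
template <std::size_t N>
constexpr std::array<std::size_t, N>
ColumnMajorPackedStrides(const std::array<std::size_t, N>& shape)
{
    std::array<std::size_t, N> strides{};
    std::size_t running = 1;
    for(std::size_t i = 0; i < N; ++i)
    {
        strides[i] = running; // stride of dim i = product of shape[0..i-1]
        running *= shape[i];
    }
    return strides;
}
// Example: shape {4, 2, 8} -> strides {1, 4, 8}.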
/**
* \brief Create naive tensor descriptor from nested shape.
*
* \param shape Tensor shape.
* \param strides Tensor strides.
* \return Unrolled descriptor.
*/
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
......@@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
}
/**
......@@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
}
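A short usage sketch for the two overloads above (the nested shape and strides are illustrative, not taken from the commit):

// Hedged usage sketch; assumes ck::make_tuple and ck::wrapper::make_layout
// as declared above.
const auto shape   = ck::make_tuple(ck::make_tuple(2, 4), 8); // nested shape
const auto strides = ck::make_tuple(ck::make_tuple(1, 2), 8); // explicit strides
const auto layout_explicit = ck::wrapper::make_layout(shape, strides);
// Without strides, packed column-major strides are generated internally:
const auto layout_packed = ck::wrapper::make_layout(shape);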
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
* \brief Get dim.
*
* \param dim Dimension.
* \return The same dimension.
*/
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
......@@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
},
Number<old_shape_dims>{});
const auto& flatten_desc = layout.GetUnnestedDescriptor();
const auto& flatten_desc = layout.GetUnrolledDescriptor();
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}
......@@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
}
// size
// Get dim size (could be returned from get function)
/**
* \private
* \brief Get size.
*
* \param dim Size.
* \return The same size.
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
......@@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
* \param layout Layout to get Shape of.
* \return Requested length.
*/
template <index_t idx, typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.template GetLength<idx>();
}
......@@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
* \param layout Layout to calculate shape size.
* \return Requested size.
*/
template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.GetLengths();
}
......@@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
* \param layout Layout to calculate rank.
* \return Requested rank.
*/
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Shape::Size();
}
......@@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Always 1.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
{
return 1;
}
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Always 1.
*/
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
/**
* \brief Hierarchical rank.
......@@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
* \param layout Layout to calculate depth.
* \return Requested depth.
*/
template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto& shape = layout.GetShape();
return TupleDepth(shape);
......@@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Always 0.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
{
return 0;
}
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Always 0.
*/
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
/**
* \brief Hierarchical depth.
......
......@@ -6,12 +6,22 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
namespace ck {
namespace wrapper {
namespace {
// Calculate shape for partition based on number of threads per each dim and
// previous shape
/**
* \brief Calculate the partition shape based on the number of threads per
* dimension and the base tensor shape.
*
* \param shape Base tensor shape.
* \param thread_lengths Tuple of thread lengths.
* \return Partition shape.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths)
......@@ -20,265 +30,165 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurse
return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
}
else
{
const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
return slice_len;
}
},
Number<Tuple<Ts...>::Size()>{});
}
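The helper divides each (possibly nested) shape extent by the matching thread count. A hedged worked example with illustrative values, using std::array in place of ck::Tuple:

#include <array>
constexpr std::array<int, 2> base_shape{8, 16};  // base tensor shape
constexpr std::array<int, 2> thread_lens{2, 4};  // threads per dimension
constexpr std::array<int, 2> part_shape{base_shape[0] / thread_lens[0],
                                        base_shape[1] / thread_lens[1]};
static_assert(part_shape[0] == 4 && part_shape[1] == 4,
              "each thread owns a 4x4 slice");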
// Calculate shape for partition based on number of threads per each dim,
// previous strides and steps
template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
__host__ __device__ constexpr auto
CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const FlattenDescType& flatten_desc)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
const auto unrolled_shape = UnrollNestedTuple(shape);
constexpr auto dims = decltype(unrolled_thread_lengths)::Size();
using UnrolledStepsType = decltype(UnrollNestedTuple(steps));
using I1 = Number<1>;
const auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
// By default raked partition
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
{
// Compiletime partition
if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
else
{
// Runtime partition
if(steps.At(num_i) == 1)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
},
Number<dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
const auto upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
index_t& thread_id)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), Tuple<>{}, thread_id);
}
else
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), steps.At(num_i), thread_id);
}
}
else
{
// Update thread_id after each dim
const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
thread_id /= thread_lengths.At(num_i);
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return dim_thread_id;
}
else
{
// Apply step
return steps.At(num_i) * dim_thread_id;
}
}
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
return slice_len;
},
Number<Tuple<Ls...>::Size()>{});
}
// Convert integer thread_idx to tuple index with steps applied
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const index_t thread_id)
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \return Tuple with the number of blocks along each dimension.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape)
{
// Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
index_t thread_id_copy = thread_id;
return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong tile_shape size.");
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
Number<Tuple<Ls...>::Size()>{});
}
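The grid size is simply the elementwise quotient of the tensor shape by the tile shape; for example (illustrative values):

// Hedged example: a 256x128 tensor covered by 64x32 tiles gives a 4x4 grid.
static_assert(256 / 64 == 4 && 128 / 32 == 4, "grid = shape / tile_shape");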
// Apply steps to index represented as tuple
template <typename... Steps, typename... Idxs>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
const Tuple<Idxs...>& block_idxs)
/**
* \brief Calculate scaled offset for new partition/tile.
*
* \param thread_idxs Thread multi-dimensional index.
* \param partition_lengths_seq Sequence with the partition shape.
* \param old_offset_idxs Multi-index offset inherited from the base tensor.
* \return Scaled offset multi-index.
*/
template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
__host__ __device__ constexpr auto
CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
const PartitionLengthsSeq& partition_lengths_seq,
const OldOffsetIdxs& old_offset_idxs)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
}
else
{
return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
}
}
else
{
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return block_idxs.At(num_i);
}
else
{
// apply step
return steps.At(num_i) * block_idxs.At(num_i);
}
}
},
Number<Tuple<Idxs...>::Size()>{});
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
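The offset is computed elementwise as thread_idxs * partition_lengths_seq + old_offset_idxs. A hedged worked example with illustrative values:

// Thread multi-index {1, 2}, per-thread slice lengths {4, 2}, inherited
// offset {0, 8} -> new offset {4, 12}.
static_assert(1 * 4 + 0 == 4, "first dimension");
static_assert(2 * 2 + 8 == 12, "second dimension");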
// User passes only shape per block to the make_local_tile function. This function calculates
// block layout based on the shape.
template <typename... Ts, typename... BlockDims>
__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
const Tuple<BlockDims...>& tile_shape)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
}
else
{
return shape.At(num_i) / tile_shape.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
}
} // namespace
/**
* \brief Create local partition for thread.
* \brief Create local partition for a thread (currently only packed partitions
* are supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads.
* \param thread_lengths Layout of threads (must not be nested).
* \param thread_id Thread index represented as integer.
* \param steps Thread step (default=1, raked partition)
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id,
const StepsTuple steps = StepsTuple{})
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
{
// Create shape, strides and layout for new partition tensor
const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto partition_desc =
CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
partition_shape, partition_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
const auto partition_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
partition_layout);
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
constexpr auto partition_shape =
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
// Create Thread Cluster Descriptor
constexpr auto partition_lengths_seq = generate_sequence_v2(
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
partition_shape, flatten_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
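A hedged usage sketch for the new signature (assumes a device-side `tensor` created earlier with ck::wrapper::make_tensor; the 2x4 thread arrangement is illustrative):

// Split the tensor across 2x4 threads; each thread gets its own slice
// addressed through multi-index offsets rather than a shifted pointer.
const auto thread_lengths = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{});
auto thread_slice = ck::wrapper::make_local_partition(
    tensor, thread_lengths, ck::get_thread_local_1d_id());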
/**
* \brief Create local tile for thread block.
* \brief Create local tile for a thread block (currently only packed tiles
* are supported).
*
* \note For now, use a 2D tile_shape to get the best performance.
*
* \param tensor Tensor for tile.
* \param tile_shape Shapes of requested tile.
* \param block_idx Block index represented as tuple.
* \param steps Block step (default=1, raked partition)
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxTuple,
typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxTuple& block_idx,
const StepsTuple steps = StepsTuple{})
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
{
// Create block lengths, strides and layout for new tile tensor
const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto tile_desc =
CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
const auto tile_layout = Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(
tile_shape, tile_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx);
const auto tile_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
tile_layout);
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
if constexpr(BlockShapeTuple::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto tile_shape_seq =
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
Number<BlockShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
}
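A hedged usage sketch (again assuming a `tensor` created earlier with ck::wrapper::make_tensor; the 64x32 tile is illustrative and, being 2D, takes the optimized BlockToCTileMap path above):

const auto tile_shape = ck::make_tuple(ck::Number<64>{}, ck::Number<32>{});
auto block_tile =
    ck::wrapper::make_local_tile(tensor, tile_shape, ck::get_block_1d_id());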
} // namespace wrapper
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -10,6 +10,7 @@
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
namespace ck {
namespace wrapper {
......@@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
typename UnrolledDescriptorType>
struct Tensor;
template <typename FromType, typename ToType>
......@@ -45,13 +42,22 @@ struct Slice
__host__ __device__ constexpr Slice() : from_(), to_() {}
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
/**
* \brief Calculate slice range.
*
* \param dim Dimension size.
* \return Slice range.
*/
template <typename T>
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
{
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
......@@ -101,40 +107,27 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
template <MemoryTypeEnum MemoryType,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType>
typename UnrolledDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Tensor<MemoryType,
ElementType,
Shape,
UnnestedDescriptorType,
0 /*NumVectors*/,
0 /*ScalarPerVector*/>(pointer, layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType>
constexpr auto make_register_tensor()
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
return Tensor<MemoryType,
ElementType,
Tuple<Number<NumVectors>>,
std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
NumVectors,
ScalarPerVector>(layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
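A hedged usage sketch for the new make_register_tensor signature, where the layout is now passed in explicitly and only the memory type and element type remain template arguments (values illustrative):

const auto reg_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}),
                                                 ck::make_tuple(ck::Number<1>{}));
auto reg_tensor = ck::wrapper::make_register_tensor<
    ck::wrapper::MemoryTypeEnum::Vgpr, float>(reg_layout);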
/**
......@@ -146,15 +139,9 @@ constexpr auto make_register_tensor()
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return tensor.GetLayout();
}
......@@ -170,15 +157,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return size<Idxs...>(tensor.GetLayout());
}
......@@ -194,15 +175,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return rank<Idxs...>(tensor.GetLayout());
}
......@@ -218,15 +193,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return depth<Idxs...>(tensor.GetLayout());
}
......@@ -240,15 +209,9 @@ __host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return shape(tensor.GetLayout());
}
......
......@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
return 0;
}
throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
return 0;
}
throw std::runtime_error(
"Conv_bwd_data: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
ComputeTypeA v_a = 0;
ComputeTypeB v_b = 0;
for(int k = 0; k < K; ++k)
{
ComputeTypeA v_a;
ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
......@@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c;
CDataType v_c = 0;
arg.c_element_op_(v_c, v_acc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace tensor_operation {
......@@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
return 0;
}
throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,
......
......@@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory<
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
5,
3>>
{
using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
5,
3>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
is_same_v<DBetaDataType, F32>)
{
add_device_groupnorm_bwd_gamma_beta_f32_instances(op_ptrs);
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_layernorm2d_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DYDataType, F16> && is_same_v<XDataType, F16> &&
is_same_v<MeanInvStdDataType, F16> && is_same_v<DGammaDataType, F16> &&
is_same_v<DBetaDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_layernorm2d_bwd_gamma_beta_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
is_same_v<DBetaDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_layernorm2d_bwd_gamma_beta_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......@@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format on
>;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
......@@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance
......
......@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
// clang-format on
>;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
// clang-format off
    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
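// Illustration (hypothetical alias, not in this commit): each
// (GemmSpecialization, PipelineVersion, LoopScheduler) combination of the
// template above names a distinct instance list, e.g.:
using IrregularDefaultV1Instances =
    device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
        GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Default>;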
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
@@ -112,6 +148,52 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance
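// Hypothetical usage sketch (not part of this commit): collect the instances
// registered above. "DeviceOpPtrVector" stands in for the exact
// std::vector<std::unique_ptr<DeviceGemmSplitK<...>>> type, which is
// abbreviated in this diff; each element is one tuned kernel configuration and
// exposes GetTypeString() via CK's BaseOperator for logging or profiling.
template <typename DeviceOpPtrVector>
auto collect_mk_nk_mn_splitk_instances()
{
    DeviceOpPtrVector ops;
    ck::tensor_operation::device::instance::
        add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(ops);
    return ops; // e.g. iterate and call op->GetTypeString() on each element
}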
@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
......