Commit 6368be50 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 32806d5f 71d6ede7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
@@ -14,8 +17,8 @@
using F16 = ck::half_t;
using F32 = float;
-using ADataType = F16;
-using BDataType = F16;
+using ADataType = F32;
+using BDataType = F32;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
2, // NumDim_m, {N, C}
2, // NumDim_n, {H, W}
1, // NumDim_k, {D}
-8, // MPerThread
-8, // NPerThread
-8, // KPerThread
-ck::Sequence<8>, // InScalarPerVectorSeq
+4, // MPerThread
+4, // NPerThread
+4, // KPerThread
+ck::Sequence<4>, // InScalarPerVectorSeq
ck::Sequence<4>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include <random>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
@@ -174,6 +174,11 @@ struct PassThrough
{
y = x;
}
template <>
__host__ __device__ void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
......
@@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const InDataTypePointerTuple p_in_global_tuple,
-XDataType* const __restrict__ p_x_lds,
+XDataType* const __restrict__ p_x_lds_,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
@@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
+p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
auto in_thread_buf_tuple = generate_tuple(
[&](auto) {
......
@@ -328,7 +328,7 @@ struct WmmaSelector
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
-static constexpr auto GetWmma<int4_t, int, 16, 16>()
+static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
......
@@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
return math::max(TupleDepth<depth + 1>(Ts{})...);
}
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<from + i>;
return tuple.At(Idx{});
},
Number<to - from>{});
}
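A brief illustration of the new TupleSlice helper (not part of the commit; it assumes the ck tuple utilities above are in scope, and the values are made up):
// TupleSlice<From, To>(t) keeps elements [From, To) of t and drops the rest.
constexpr auto whole = ck::make_tuple(ck::Number<2>{}, ck::Number<3>{}, ck::Number<5>{}, ck::Number<7>{});
constexpr auto inner = ck::TupleSlice<1, 3>(whole); // Tuple<Number<3>, Number<5>>
static_assert(decltype(inner)::Size() == 2, "half-open range [1, 3) keeps two elements");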
} // namespace ck
...@@ -14,11 +14,9 @@ namespace wrapper { ...@@ -14,11 +14,9 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes * (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged. * (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
*/ */
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
struct Layout struct Layout
{ {
private: private:
...@@ -31,7 +29,7 @@ struct Layout ...@@ -31,7 +29,7 @@ struct Layout
{ {
return generate_tuple( return generate_tuple(
[&](auto) { [&](auto) {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{ {
// runtime layout // runtime layout
return index_t(0); return index_t(0);
...@@ -45,27 +43,6 @@ struct Layout ...@@ -45,27 +43,6 @@ struct Layout
Number<Tuple<Ts...>::Size()>{}); Number<Tuple<Ts...>::Size()>{});
} }
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
// Generate LowerDims in Compile-time for MergeTransform using passed Type // Generate LowerDims in Compile-time for MergeTransform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge) // If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element) // If tuple is element, then pass through (sequence with one element)
...@@ -207,33 +184,15 @@ struct Layout ...@@ -207,33 +184,15 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
} }
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using FlattenDescriptorType =
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType = using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>; remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>; using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
template <typename... ShapeDims, typename... IdxDims> template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto __host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape, TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx, const Tuple<IdxDims...>& idx,
const FlattenDescriptorType& naive_descriptor) const UnnestedDescriptorType& naive_descriptor)
{ {
if constexpr(Tuple<IdxDims...>::Size() == I1) if constexpr(Tuple<IdxDims...>::Size() == I1)
{ {
...@@ -256,48 +215,33 @@ struct Layout ...@@ -256,48 +215,33 @@ struct Layout
} }
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc( using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>; Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
public: public:
__host__ __device__ constexpr auto GetElementSpaceSize() const __host__ __device__ constexpr auto GetElementSpaceSize() const
{ {
return flatten_descriptor_.GetElementSpaceSize(); return unnested_descriptor_.GetElementSpaceSize();
} }
__host__ __device__ Layout() = delete; __host__ __device__ Layout() = delete;
/** /**
* \brief Layout constructor. * \brief Layout constructor.
* *
* \param shape Shape for layout. * \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed). * \param unnested_descriptor Descriptor
*/ */
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides) __host__ __device__ constexpr Layout(const Shape& shape,
: flatten_descriptor_{}, shape_(shape), strides_(strides) const UnnestedDescriptorType& unnested_descriptor)
: shape_(shape)
{ {
// Construct if runtime mode // Construct if runtime mode
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}
/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__ __device__ constexpr Layout(const Shape& shape)
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
{ {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) unnested_descriptor_ = unnested_descriptor;
{ descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ = merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_); TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
} }
} }
...@@ -310,9 +254,9 @@ struct Layout ...@@ -310,9 +254,9 @@ struct Layout
template <typename Idxs> template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const __host__ __device__ constexpr index_t operator()() const
{ {
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(), static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout."); "Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{})); using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
} }
...@@ -339,7 +283,7 @@ struct Layout ...@@ -339,7 +283,7 @@ struct Layout
else else
{ {
// Custom index, need to transform descriptor // Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_); const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
} }
} }
...@@ -351,7 +295,7 @@ struct Layout ...@@ -351,7 +295,7 @@ struct Layout
* \return Calculated size. * \return Calculated size.
*/ */
template <index_t IDim> template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const __host__ __device__ constexpr auto GetLength() const
{ {
const auto elem = shape_.At(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
...@@ -371,7 +315,7 @@ struct Layout ...@@ -371,7 +315,7 @@ struct Layout
* *
* \return Calculated size. * \return Calculated size.
*/ */
__host__ __device__ constexpr index_t GetLengths() const __host__ __device__ constexpr auto GetLengths() const
{ {
const auto unrolled_shape = UnrollNestedTuple(shape_); const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...@@ -385,13 +329,6 @@ struct Layout ...@@ -385,13 +329,6 @@ struct Layout
*/ */
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; } __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
/**
* \brief Strides getter.
*
* \return Strides.
*/
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
/** /**
* \brief Get default lengths (tuple filled with Shape length elements). * \brief Get default lengths (tuple filled with Shape length elements).
* *
...@@ -417,17 +354,26 @@ struct Layout ...@@ -417,17 +354,26 @@ struct Layout
* *
* \return Default descriptor. * \return Default descriptor.
*/ */
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor() __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
{ {
return merged_nests_descriptor_; return merged_nests_descriptor_;
} }
/**
* \brief Get unnested descriptor (with unrolled dims)
*
* \return Flatten descriptor.
*/
__host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
{
return unnested_descriptor_;
}
private: private:
FlattenDescriptorType flatten_descriptor_; UnnestedDescriptorType unnested_descriptor_;
Descriptor1dType descriptor_1d_; Descriptor1dType descriptor_1d_;
MergedNestsDescriptorType merged_nests_descriptor_; MergedNestsDescriptorType merged_nests_descriptor_;
const Shape shape_; const Shape shape_;
const DeducedStrides strides_;
}; };
} // namespace wrapper } // namespace wrapper
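For orientation, a hedged sketch of how a Layout is built and queried after this refactor (illustrative values; make_layout is the helper from layout_utils.hpp that produces the unnested descriptor):
// Before: Layout<Shape, Strides>(shape, strides), with strides stored in the layout.
// After:  Layout<Shape, UnnestedDescriptorType>(shape, unnested_descriptor), with the
//         descriptor normally produced for you by make_layout(...).
const auto shape   = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{});
const auto strides = ck::make_tuple(ck::Number<1>{}, ck::Number<2>{});
const auto layout  = ck::wrapper::make_layout(shape, strides); // column-major 2x4
const auto n_elems = layout.GetLengths();                      // 8, product of all unrolled dims
const auto& desc   = layout.GetUnnestedDescriptor();           // new getter replacing the removed GetStrides()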
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "../utils/tensor_utils.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform generic copy between two tensors. Tensors must have the
* same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
if constexpr(!SrcTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(src_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else if constexpr(!DstTensorType::IsDynamicBuffer)
{
using SizeType = decltype(size(dst_tensor));
static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
}
else
{
for(int i = 0; i < size(src_tensor); i++)
{
dst_tensor(i) = src_tensor(i);
}
}
}
} // namespace wrapper
} // namespace ck
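A hedged usage sketch for copy (illustrative only; p_src and p_dst are hypothetical device pointers, and make_layout/make_tensor come from the other wrapper headers touched in this commit):
// Element-wise copy between two global-memory views with the same 1-D shape.
const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}));
auto src = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_src, layout);
auto dst = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_dst, layout);
ck::wrapper::copy(src, dst); // both tensors must have the same size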
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
namespace ck { namespace ck {
...@@ -15,14 +16,14 @@ namespace wrapper { ...@@ -15,14 +16,14 @@ namespace wrapper {
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR). * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type. * \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component). * \tparam Shape Tensor shape (layout component).
* \tparam Strides Tensor strides (layout component). * \tparam UnnestedDescriptorType Unnested descriptor (layout component).
* \tparam NumVectors Number of vectors (only for VGPR, SGPR). * \tparam NumVectors Number of vectors (only for VGPR, SGPR).
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR). * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
*/ */
template <MemoryTypeEnum BufferAddressSpace, template <MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, // param for Register memory index_t NumVectors, // param for Register memory
index_t ScalarPerVector // param for Register memory index_t ScalarPerVector // param for Register memory
> >
...@@ -31,49 +32,19 @@ struct Tensor ...@@ -31,49 +32,19 @@ struct Tensor
private: private:
// Check if Tuple contains Slice object // Check if Tuple contains Slice object
template <typename T> template <typename T>
constexpr static bool IsSlicing(T&&) __host__ __device__ constexpr static bool IsSlicing(T&&)
{ {
return is_detected<is_slice, T>::value; return is_detected<is_slice, T>::value;
} }
template <typename... Ts> template <typename... Ts>
constexpr static bool IsSlicing(Tuple<Ts...>&&) __host__ __device__ constexpr static bool IsSlicing(Tuple<Ts...>&&)
{ {
return (IsSlicing(Ts{}) || ...); return (IsSlicing(Ts{}) || ...);
} }
// Calculate first index of new tensor after slice
// It is needed to calculate offset for new tensor
template <typename... Ts>
constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
{
const auto start_idx_for_sliced_tensor = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return GetStartIdxForSlicedTensor(idx.At(num_i));
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if slice, return the beginning of the interval
return idx.At(num_i).from_;
}
else
{
// if one dim selected
return idx.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
return start_idx_for_sliced_tensor;
}
// Calculate new tensor shape after slice // Calculate new tensor shape after slice
template <typename... Ts, typename ShapeTmpType> template <typename... Ts, typename ShapeTmpType>
constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx, __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape) const const ShapeTmpType& shape) const
{ {
// Pack each value in tuple to remove empty tuples after generation // Pack each value in tuple to remove empty tuples after generation
...@@ -112,48 +83,116 @@ struct Tensor ...@@ -112,48 +83,116 @@ struct Tensor
return UnrollNestedTuple<0, 1>(new_shape); return UnrollNestedTuple<0, 1>(new_shape);
} }
template <typename... Ts, typename StridesTmpType> // Generate Freeze for each of nested shape
constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx, template <typename T, typename ShapeTmpType>
const StridesTmpType& strides) const __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx,
const ShapeTmpType& shape) const
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
// dimension offset from idx
const auto dim = unrolled_shape.At(Number<i>{});
const auto dim_idx = idx % dim;
idx /= dim;
return make_freeze_transform(dim_idx);
},
Number<decltype(unrolled_shape)::Size()>{});
}
template <typename... Ts, typename ShapeTmpType>
__host__ __device__ constexpr auto
GetTransformsFromSlicedTensor(const Tuple<Ts...>& idx, const ShapeTmpType& shape) const
{ {
// Pack each value in tuple to remove empty tuples after generation // Pack each value in tuple to remove empty tuples after generation
auto new_strides = generate_tuple( auto transforms = generate_tuple(
[&](auto i) { [&](auto i) {
constexpr auto num_i = Number<i>{}; constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{ {
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{})) return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i));
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
}
else
{
// if tuple then recurrence
return make_tuple(
GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
}
} }
else if constexpr(is_detected<is_slice, else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value) tuple_element_t<i.value, Tuple<Ts...>>>::value)
{ {
// Stride will be the same
return make_tuple(strides.At(num_i)); const auto from = idx.At(num_i).from_;
const auto dim = shape.At(num_i);
const auto range = idx.At(num_i).range(dim);
return make_slice_transform(range, from, from + range);
} }
else else
{ {
// remove dimension for just value // remove dimension for just value
return Tuple<>{}; return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
} }
}, },
Number<Tuple<Ts...>::Size()>{}); Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return // Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_strides); return UnrollNestedTuple(transforms);
}
// There is no output for Freeze transform
template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&) const
{
return Sequence<>{};
}
template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto
GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&) const
{
return Sequence<i>{};
}
template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const
{
return Tuple<>{};
}
template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto
GenerateUpperDims(const Tuple<Transforms...>& transforms) const
{
constexpr auto num_transforms = Tuple<Transforms...>::Size();
// Deduce Sequence element for specific transform
const auto currect_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
if constexpr(is_same_v<decltype(currect_elem), const Sequence<>>)
{
const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(currect_elem), next_tuple);
}
else
{
// Increase i if current_elem is Slice transform
const auto next_tuple =
GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(currect_elem), next_tuple);
}
}
template <typename... Ts, typename ShapeTmpType, typename FlattenDescriptor>
__host__ __device__ constexpr auto
GetDescriptorFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape,
const FlattenDescriptor& flatten_desc) const
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
const auto transforms = GetTransformsFromSlicedTensor(idx, shape);
using TransformsTupleType = decltype(transforms);
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
} }
public: public:
using ElementSpaceSize = decltype(Layout<Shape, Strides>{ using ElementSpaceSize = decltype(Layout<Shape, UnnestedDescriptorType>{
Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = ElementType; // DataType using TensorElementType = ElementType; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
...@@ -161,18 +200,20 @@ struct Tensor ...@@ -161,18 +200,20 @@ struct Tensor
BufferAddressSpace == MemoryTypeEnum ::Vgpr); BufferAddressSpace == MemoryTypeEnum ::Vgpr);
__host__ __device__ Tensor() = delete; __host__ __device__ Tensor() = delete;
__host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout) __host__ __device__ Tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
: layout_(layout), : layout_(layout),
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())) buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
{ {
} }
__host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout) __host__ __device__ Tensor(const Layout<Shape, UnnestedDescriptorType>& layout)
: layout_(layout)
{ {
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register."); static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
} }
__host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const __host__ __device__ constexpr const Layout<Shape, UnnestedDescriptorType>& GetLayout() const
{ {
return layout_; return layout_;
} }
...@@ -182,21 +223,14 @@ struct Tensor ...@@ -182,21 +223,14 @@ struct Tensor
__host__ __device__ auto operator[](const Tuple<Ts...>& idx) const __host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
{ {
static_assert(IsDynamicBuffer, "Register slice is not supported"); static_assert(IsDynamicBuffer, "Register slice is not supported");
// Calculate offset based on first idx for new tensor const auto& shape = layout_.GetShape();
const index_t offset = layout_(GetStartIdxForSlicedTensor(idx)); auto new_shape = GetShapeFromSlicedTensor(idx, shape);
auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape()); const auto& flatten_desc = layout_.GetUnnestedDescriptor();
if constexpr(is_same_v<Strides, Tuple<>>) auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc);
{ const auto new_layout =
auto new_layout = make_layout(new_shape); Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout); return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
}
else
{
auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
auto new_layout = make_layout(new_shape, new_strides);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
}
} }
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false> template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
...@@ -222,19 +256,11 @@ struct Tensor ...@@ -222,19 +256,11 @@ struct Tensor
} }
else else
{ {
if constexpr(is_same_v<Strides, Tuple<>>) constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
{ Shape{},
constexpr index_t offset = UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}]; return buffer_[Number<offset>{}];
} }
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
}
}
} }
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false> template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
...@@ -260,20 +286,12 @@ struct Tensor ...@@ -260,20 +286,12 @@ struct Tensor
} }
else else
{ {
if constexpr(is_same_v<Strides, Tuple<>>) constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
{ Shape{},
constexpr index_t offset = UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
}
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{}); return buffer_(Number<offset>{});
} }
} }
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false> template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx) __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
...@@ -292,6 +310,8 @@ struct Tensor ...@@ -292,6 +310,8 @@ struct Tensor
return layout_.GetDefaultDescriptor(); return layout_.GetDefaultDescriptor();
} }
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
private: private:
using DynamicBufferType = DynamicBuffer<BufferAddressSpace, using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType, ElementType,
...@@ -306,7 +326,7 @@ struct Tensor ...@@ -306,7 +326,7 @@ struct Tensor
// If register use static buffer, else use dynamic buffer // If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>; using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
const Layout<Shape, Strides> layout_; const Layout<Shape, UnnestedDescriptorType> layout_;
Buffer buffer_; Buffer buffer_;
}; };
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -22,11 +22,57 @@ namespace wrapper { ...@@ -22,11 +22,57 @@ namespace wrapper {
// Disable from doxygen docs generation // Disable from doxygen docs generation
/// @cond /// @cond
// forward declaration // forward declaration
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
struct Layout; struct Layout;
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return Number<1>{};
}
else
{
return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
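For intuition, the packed column-major strides generated above work out as follows for a small example (illustrative only):
// shape (2, 3, 4)  ->  strides (1, 2, 6):
//   dim 0 stride: 1
//   dim 1 stride: 2            (product of the dims before it)
//   dim 2 stride: 2 * 3 = 6
// element (i, j, k) therefore lands at flat offset i + 2 * j + 6 * k.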
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
{
// if not passed, then generate
const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
else
{
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace
/// @endcond /// @endcond
// make_* // make_*
...@@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple()); ...@@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \return Constructed layout. * \return Constructed layout.
*/ */
template <typename Shape, typename Strides> template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape, __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
const Strides& strides)
{ {
return Layout<Shape, Strides>(shape, strides); using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
} }
/** /**
...@@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh ...@@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
* \return Constructed layout. * \return Constructed layout.
*/ */
template <typename Shape> template <typename Shape>
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape) __host__ __device__ constexpr auto make_layout(const Shape& shape)
{ {
return Layout<Shape, Tuple<>>(shape); using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
} }
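A short sketch of the two overloads (illustrative shapes and strides):
// Explicit strides (row-major here: the first dimension moves in steps of 4).
const auto row_major = ck::wrapper::make_layout(ck::make_tuple(2, 4), ck::make_tuple(4, 1));
// No strides given: packed column-major strides (1, 2) are generated internally.
const auto col_major = ck::wrapper::make_layout(ck::make_tuple(2, 4));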
// Layout helpers // Layout helpers
...@@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple) ...@@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout. * \param layout Layout to create sub layout.
* \return Requested sub layout. * \return Requested sub layout.
*/ */
template <index_t idx, typename Shape, typename Strides> template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout) __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
{ {
const auto& shape = layout.GetShape(); const auto& shape = layout.GetShape();
const auto& new_shape = get<idx>(shape); const auto new_shape = get<idx>(shape);
static_assert(is_detected<is_tuple, decltype(new_shape)>::value, static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
"Shape of sub layout must be tuple"); "Shape of sub layout must be tuple");
if constexpr(is_same_v<Strides, Tuple<>>)
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto transforms = generate_tuple(
[&](auto i) {
// Compare Idx with shape
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
{
// Remove dimension
return make_freeze_transform(Number<0>{});
}
else
{ {
// If stride not passed, create without strides return make_pass_through_transform(unrolled_shape.At(i));
return make_layout(new_shape);
} }
},
Number<old_shape_dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
return Sequence<>{};
else else
{ {
const auto& strides = layout.GetStrides(); return Sequence<i.value - shape_offset>{};
const auto& new_strides = get<idx>(strides);
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
"Strides of sub layout must be tuple");
return make_layout(new_shape, new_strides);
} }
},
Number<old_shape_dims>{});
const auto& flatten_desc = layout.GetUnnestedDescriptor();
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
} }
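A hedged sketch of get<idx> on a nested layout (illustrative; the selected shape element must itself be a tuple, as the static_assert above requires):
// Nested shape ((2, 3), 4): get<0> keeps the (2, 3) sub-layout and freezes the
// remaining dimension of the underlying flat descriptor.
const auto nested = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 3), 4));
const auto sub    = ck::wrapper::get<0>(nested); // Layout with shape (2, 3)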
/** /**
...@@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim) ...@@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim)
* \param layout Layout to get Shape of. * \param layout Layout to get Shape of.
* \return Requested length. * \return Requested length.
*/ */
template <index_t idx, typename Shape, typename Strides> template <index_t idx, typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout) __host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
{ {
return layout.template GetLength<idx>(); return layout.template GetLength<idx>();
} }
...@@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout) ...@@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requested size. * \return Requested size.
*/ */
template <typename... ShapeDims> template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape) __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
{ {
const auto unrolled_shape = UnrollNestedTuple(shape); const auto unrolled_shape = UnrollNestedTuple(shape);
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...@@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape) ...@@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
* \param layout Layout to calculate shape size. * \param layout Layout to calculate shape size.
* \return Requested size. * \return Requested size.
*/ */
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout) __host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
{ {
return layout.GetLengths(); return layout.GetLengths();
} }
...@@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout) ...@@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requested length. * \return Requested length.
*/ */
template <index_t idx, typename... Ts> template <index_t idx, typename... Ts>
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple) __host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
{ {
return size(tuple.At(Number<idx>{})); return size(tuple.At(Number<idx>{}));
} }
...@@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem) ...@@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem)
* \param layout Layout to calculate rank. * \param layout Layout to calculate rank.
* \return Requested rank. * \return Requested rank.
*/ */
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout) __host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
{ {
return Shape::Size(); return Shape::Size();
} }
...@@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem) ...@@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
* \param layout Layout to calculate depth. * \param layout Layout to calculate depth.
* \return Requested depth. * \return Requested depth.
*/ */
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout) __host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
{ {
const auto& shape = layout.GetShape(); const auto& shape = layout.GetShape();
return TupleDepth(shape); return TupleDepth(shape);
...@@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem) ...@@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
return depth(get<Idxs...>(elem)); return depth(get<Idxs...>(elem));
} }
/**
* \brief Get Layout strides.
*
* \param layout Layout to get strides from.
* \return Requsted strides.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
{
return layout.GetStrides();
}
/** /**
* \brief Get Layout shape. * \brief Get Layout shape.
* *
* \param layout Layout to get shape from. * \param layout Layout to get shape from.
* \return Requested shape. * \return Requested shape.
*/ */
template <typename Shape, typename Strides> template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout) __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
{ {
return layout.GetShape(); return layout.GetShape();
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
namespace ck {
namespace wrapper {
namespace {
// Calculate shape for partition based on number of threads per each dim and
// previous shape
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
}
else
{
const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
return slice_len;
}
},
Number<Tuple<Ts...>::Size()>{});
}
// Calculate shape for partition based on number of threads per each dim,
// previous strides and steps
template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
__host__ __device__ constexpr auto
CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const FlattenDescType& flatten_desc)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
const auto unrolled_shape = UnrollNestedTuple(shape);
constexpr auto dims = decltype(unrolled_thread_lengths)::Size();
using UnrolledStepsType = decltype(UnrollNestedTuple(steps));
using I1 = Number<1>;
const auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
// By default raked partition
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
{
// Compiletime partition
if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
else
{
// Runtime partition
if(steps.At(num_i) == 1)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
},
Number<dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
const auto upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
index_t& thread_id)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), Tuple<>{}, thread_id);
}
else
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), steps.At(num_i), thread_id);
}
}
else
{
// Update thread_id after each dim
const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
thread_id /= thread_lengths.At(num_i);
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return dim_thread_id;
}
else
{
// Apply step
return steps.At(num_i) * dim_thread_id;
}
}
},
Number<Tuple<Ls...>::Size()>{});
}
// Convert integer thread_idx to tuple index with steps applied
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const index_t thread_id)
{
// Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
index_t thread_id_copy = thread_id;
return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
}
// Apply steps to index represented as tuple
template <typename... Steps, typename... Idxs>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
const Tuple<Idxs...>& block_idxs)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
}
else
{
return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
}
}
else
{
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return block_idxs.At(num_i);
}
else
{
// apply step
return steps.At(num_i) * block_idxs.At(num_i);
}
}
},
Number<Tuple<Idxs...>::Size()>{});
}
// User passes only shape per block to the make_local_tile function. This function calculates
// block layout based on the shape.
template <typename... Ts, typename... BlockDims>
__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
const Tuple<BlockDims...>& tile_shape)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
}
else
{
return shape.At(num_i) / tile_shape.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
}
} // namespace
/**
* \brief Create local partition for thread.
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads.
* \param thread_id Thread index represented as integer.
* \param steps Thread step (default=1, raked partition)
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id,
const StepsTuple steps = StepsTuple{})
{
// Create shape, strides and layout for new partition tensor
const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto partition_desc =
CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
partition_shape, partition_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
const auto partition_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
partition_layout);
}
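A hedged sketch of carving a per-thread partition (illustrative sizes; tensor is assumed to be a (256, 128) global-memory wrapper tensor created elsewhere, and the call runs inside a HIP kernel):
// Distribute the tensor over a 4 x 64 thread arrangement (raked by default).
const auto thread_lengths = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const auto per_thread     = ck::wrapper::make_local_partition(
    tensor, thread_lengths, static_cast<ck::index_t>(threadIdx.x)); // (64, 2) elements per thread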
/**
* \brief Create local tile for thread block.
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_idx Block index represented as tuple.
* \param steps Block step (default=1, raked partition)
* \return Tile tensor.
*/
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxTuple,
typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxTuple& block_idx,
const StepsTuple steps = StepsTuple{})
{
// Create block lengths, strides and layout for new tile tensor
const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto tile_desc =
CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
const auto tile_layout = Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(
tile_shape, tile_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx);
const auto tile_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
tile_layout);
}
} // namespace wrapper
} // namespace ck
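And a matching sketch for make_local_tile (same assumptions as above; block_idx is passed as a tuple rather than a flat id):
// Pick this block's 64 x 32 tile out of the (256, 128) tensor.
const auto tile_shape = ck::make_tuple(ck::Number<64>{}, ck::Number<32>{});
const auto block_idx  = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                                       static_cast<ck::index_t>(blockIdx.y));
const auto tile = ck::wrapper::make_local_tile(tensor, tile_shape, block_idx);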
...@@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum; ...@@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation // Disable from doxygen docs generation
/// @cond /// @cond
// forward declarations // forward declarations
template <typename Shape, typename Strides> template <typename Shape, typename UnnestedDescriptorType>
struct Layout; struct Layout;
template <MemoryTypeEnum BufferAddressSpace, template <MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, // params for Register memory index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory index_t ScalarPerVector // param for Register memory
> >
...@@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval<T&>().IsTuple()); ...@@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \param layout Tensor layout. * \param layout Tensor layout.
* \return Constructed tensor. * \return Constructed tensor.
*/ */
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides> template <MemoryTypeEnum MemoryType,
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout) typename ElementType,
typename Shape,
typename UnnestedDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
{ {
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>( return Tensor<MemoryType,
pointer, layout); ElementType,
Shape,
UnnestedDescriptorType,
0 /*NumVectors*/,
0 /*ScalarPerVector*/>(pointer, layout);
} }
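A hedged sketch of wrapping a raw pointer (p_data is a hypothetical float* in global memory):
const auto layout = ck::wrapper::make_layout(ck::make_tuple(16, 8));
auto t = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_data, layout);
t(ck::make_tuple(3, 2)) = 1.0f; // offset resolved through the layout's descriptor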
/** /**
...@@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& l ...@@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& l
* \tparam NumVectors Number of vectors. * \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector. * \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type. * \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor. * \return Constructed tensor.
*/ */
template <MemoryTypeEnum MemoryType, template <MemoryTypeEnum MemoryType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector, index_t ScalarPerVector,
typename ElementType, typename ElementType>
typename Shape, constexpr auto make_register_tensor()
typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
{ {
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported"); const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout); return Tensor<MemoryType,
ElementType,
Tuple<Number<NumVectors>>,
std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
NumVectors,
ScalarPerVector>(layout);
} }
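A hedged sketch of the new no-argument form (the layout is now fixed internally to a 1-D shape of NumVectors):
// Per-thread VGPR storage: 4 vectors of 8 floats.
auto acc = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, 4, 8, float>();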
/** /**
...@@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout) ...@@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum BufferAddressSpace, template <MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector> index_t ScalarPerVector>
__host__ __device__ constexpr const auto& __host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>& ElementType,
tensor) Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{ {
return tensor.GetLayout(); return tensor.GetLayout();
} }
...@@ -157,12 +170,15 @@ template <index_t... Idxs, ...@@ -157,12 +170,15 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace, MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector> index_t ScalarPerVector>
__host__ __device__ constexpr index_t __host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>& ElementType,
tensor) Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{ {
return size<Idxs...>(tensor.GetLayout()); return size<Idxs...>(tensor.GetLayout());
} }
...@@ -178,12 +194,15 @@ template <index_t... Idxs, ...@@ -178,12 +194,15 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace, MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector> index_t ScalarPerVector>
__host__ __device__ constexpr index_t __host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>& ElementType,
tensor) Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{ {
return rank<Idxs...>(tensor.GetLayout()); return rank<Idxs...>(tensor.GetLayout());
} }
...@@ -199,35 +218,19 @@ template <index_t... Idxs, ...@@ -199,35 +218,19 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace, MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector> index_t ScalarPerVector>
__host__ __device__ constexpr index_t __host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>& ElementType,
tensor) Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{ {
return depth<Idxs...>(tensor.GetLayout()); return depth<Idxs...>(tensor.GetLayout());
} }
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requsted strides.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return stride(tensor.GetLayout());
}
/** /**
* \brief Get Tensor shape. * \brief Get Tensor shape.
* *
...@@ -237,12 +240,15 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ...@@ -237,12 +240,15 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
template <MemoryTypeEnum BufferAddressSpace, template <MemoryTypeEnum BufferAddressSpace,
typename ElementType, typename ElementType,
typename Shape, typename Shape,
typename Strides, typename UnnestedDescriptorType,
index_t NumVectors, index_t NumVectors,
index_t ScalarPerVector> index_t ScalarPerVector>
__host__ __device__ constexpr const auto& __host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>& ElementType,
tensor) Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{ {
return shape(tensor.GetLayout()); return shape(tensor.GetLayout());
} }
......
@@ -23,20 +23,19 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
#ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
-// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
-// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
-// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
-// NDHWGC,
-// GKZYXC,
-// ck::Tuple<>,
-// NDHWGK,
-// ck::Tuple<BF16, BF16>,
-// ck::Tuple<BF16, BF16>,
-// ck::Tuple<>,
-// BF16,
-// ScaleAdd,
-// ScaleAdd,
-// PassThrough>>>& instances);
+void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+NDHWGC,
+GKZYXC,
+ck::Tuple<>,
+NDHWGK,
+ck::Tuple<BF16, BF16>,
+ck::Tuple<BF16, BF16>,
+ck::Tuple<>,
+BF16,
+ScaleAdd,
+ScaleAdd,
+PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
@@ -152,15 +151,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
-// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
-// if constexpr(is_same_v<InDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
-// is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
-// is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType,
-// ck::bhalf_t>)
-// {
-// add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
-// op_ptrs);
-// }
+if constexpr(is_same_v<InDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
+is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
+is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, ck::bhalf_t>)
+{
+add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+op_ptrs);
+}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, ck::Tuple<int8_t, int8_t>> &&
......
@@ -21,20 +21,19 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using device_transpose_f16_instances = std::tuple<
-// FOR 16, 32, 16, 32, 16
// clang-format off
-DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<1>>,
-DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<1>, ck::Sequence<1>>,
-DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<1>, ck::Sequence<1>>
+DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<8>>,
+DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<4>>,
+DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 4, 4, 8, ck::Sequence<4>, ck::Sequence<4>>,
+DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>
// clang-format on
>;
using device_transpose_f32_instances = std::tuple<
-// for 16, 8, 16, 32, 8 -> test with instances for fp16
// clang-format off
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
-DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 4, ck::Sequence<1>, ck::Sequence<1>>,
-DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<1>, ck::Sequence<1>>
+DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<4>, ck::Sequence<1>>,
+DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<4>, ck::Sequence<4>>
// clang-format on
>;
......