"...composable_kernel_rocm.git" did not exist on "94ae7113bd05e3c39364193dba1b391a4c54a2f4"
Unverified Commit 1ddc3ec7 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-756

parents e9703d5b f5ec04f0
......@@ -37,7 +37,8 @@ __global__ void
index_t StrideC,
typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
......
......@@ -116,7 +116,15 @@ struct Max
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
}
};
__host__ __device__ static constexpr bool
......@@ -138,6 +146,15 @@ struct Max
a = b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -152,6 +169,18 @@ struct Max
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
struct Min
......@@ -159,6 +188,15 @@ struct Min
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Max();
}
return NumericLimits<T>::Max();
};
......@@ -181,6 +219,15 @@ struct Min
a = b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -195,6 +242,18 @@ struct Min
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
};
struct AMax
......
......@@ -53,7 +53,16 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{
int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length)
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
{
if constexpr(is_same_v<ConputeDataType, bhalf_t>)
{
float buf_val = ck::type_convert<float>(buf[index]);
buf_val += ck::type_convert<float>(arg.dout_.mData[i]);
buf[index] = ck::type_convert<ConputeDataType>(buf_val);
}
else
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
}
}
for(int i = 0; i < din_length; ++i)
......
......@@ -256,10 +256,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
......
......@@ -103,6 +103,7 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_avgpool_bwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>&);
#endif
#ifdef CK_ENABLE_BF16
void add_device_avgpool_bwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_bwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>&);
#endif
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>>
{
using DeviceOp = DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
add_device_avgpool_bwd_ndhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
add_device_avgpool_bwd_ndhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
add_device_avgpool_bwd_ndhwc_f32_instances(op_ptrs);
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances);
// fp32_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F32,
PassThrough,
PassThrough,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F32,
PassThrough,
PassThrough,
Add>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Row_Tuple,
ELayout,
ADataType,
BDataType,
F32_Tuple,
EDataType,
PassThrough,
PassThrough,
Add>>
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Row_Tuple,
ELayout,
ADataType,
BDataType,
F32_Tuple,
EDataType,
PassThrough,
PassThrough,
Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
}
// fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_maxpool_bwd_f16_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F16, I32, F16>>>&);
#endif
#ifdef CK_ENABLE_BF16
void add_device_maxpool_bwd_bf16_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<BF16, I32, BF16>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_maxpool_bwd_f32_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
#endif
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>>
{
using DeviceOp = DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f32_instances(op_ptrs);
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -37,6 +37,21 @@ void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool3d_fwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// BF16 - return index
void add_device_pool3d_fwd_ndhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
......@@ -98,9 +113,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_bf16_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
......
......@@ -26,7 +26,9 @@ struct DeviceMem
void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const;
void ToDevice(const void* p) const;
void ToDevice(const void* p, const std::size_t cpySize) const;
void FromDevice(void* p) const;
void FromDevice(void* p, const std::size_t cpySize) const;
void SetZero() const;
template <typename T>
void SetValue(T x) const;
......
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I32 = int32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using device_avgpool_bwd_ndhwc_f16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_bf16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_f32_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_bf16_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f16_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -83,7 +83,6 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>
// clang-format on
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
......
add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment