Unverified Commit aba0880d authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-850

parents c3d720cc d52ec016
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances(
add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 5, 3>{});
}
} // namespace instance
......
......@@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
instances, device_normalization_f16_f32_f32_f16_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
add_device_operation_instances(
instances, device_normalization_splitk_f16_f32_f32_f16_instances<Swish, 5, 3>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Swish, 5, 3>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Swish, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Swish, 5, 3>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances(
add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Pass, 2, 1>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Pass, 2, 1>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances(
add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 2, 1>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 2, 1>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 2, 1>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances(
add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Pass, 4, 3>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Pass, 4, 3>{});
}
} // namespace instance
......
......@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances(
add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 4, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 4, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 4, 3>{});
}
} // namespace instance
......
......@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -43,6 +44,32 @@ using device_normalization_f16_instances =
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
......@@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple<
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_generic_instance = std::tuple<
// clang-format off
......@@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple<
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
// clang-format off
......
set(DEVICE_POOL_FWD_INSTANCES)
set(DEVICE_POOL3D_FWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.cpp
device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool2d_fwd_nhwc_f16_instance.cpp
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f32_instance.cpp
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool2d_fwd_nhwc_f32_instance.cpp
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
endif()
add_instance_library(device_pool_fwd_instance ${DEVICE_POOL_FWD_INSTANCES})
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})
......@@ -11,7 +11,9 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
......
......@@ -11,7 +11,9 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
......
......@@ -11,14 +11,18 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
}
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, true>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
......
......@@ -11,14 +11,18 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
}
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, true>>>& instances)
std::vector<
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
......
......@@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification,
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment