Commit 29448ffd authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

merge from develop and revision for pr#881

parents 9223a5e2 8f84a012
File mode changed from 100644 to 100755
...@@ -20,12 +20,15 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) ...@@ -20,12 +20,15 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
*/ */
struct DeviceMem struct DeviceMem
{ {
DeviceMem() = delete; DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size); DeviceMem(std::size_t mem_size);
void Realloc(std::size_t mem_size);
void* GetDeviceBuffer() const; void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const; std::size_t GetBufferSize() const;
void ToDevice(const void* p) const; void ToDevice(const void* p) const;
void ToDevice(const void* p, const std::size_t cpySize) const;
void FromDevice(void* p) const; void FromDevice(void* p) const;
void FromDevice(void* p, const std::size_t cpySize) const;
void SetZero() const; void SetZero() const;
template <typename T> template <typename T>
void SetValue(T x) const; void SetValue(T x) const;
......
...@@ -102,9 +102,10 @@ struct FillMonotonicSeq ...@@ -102,9 +102,10 @@ struct FillMonotonicSeq
} }
template <typename ForwardRange> template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<decltype( auto operator()(ForwardRange&& range) const
std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)), -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
std::end(std::forward<ForwardRange>(range))))> std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{ {
(*this)(std::begin(std::forward<ForwardRange>(range)), (*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))); std::end(std::forward<ForwardRange>(range)));
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/span.hpp" #include "ck/utility/span.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/ranges.hpp" #include "ck/library/utility/ranges.hpp"
......
...@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t> ...@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_2<ck::f8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f8_t>(tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t> ...@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_3<ck::f8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f8_t>(fp32_tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -12,9 +12,46 @@ set(CK_DEVICE_INSTANCES) ...@@ -12,9 +12,46 @@ set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list}) FOREACH(subdir_path ${dir_list})
set(target_dir) set(target_dir)
IF(IS_DIRECTORY "${subdir_path}") IF(IS_DIRECTORY "${subdir_path}")
get_filename_component(target_dir ${subdir_path} NAME) set(cmake_instance)
add_subdirectory(${target_dir}) file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>) set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(add_inst EQUAL 1)
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
ENDIF() ENDIF()
ENDFOREACH() ENDFOREACH()
......
# Collect the 3-D average-pool backward instance sources to compile.
# Each data type's instance file is added when DTYPES mentions that type,
# or when DTYPES is not defined at all (i.e. build every data type).
set(DEVICE_AVGPOOL_BWD_INSTANCES)
# half-precision (fp16) instances
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
# bfloat16 instances
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
# single-precision (fp32) instances
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
# Create the object library holding whichever instances were selected above.
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Shorthand aliases for the data types and tensor layout used by the
// avg-pool-3d backward instance lists below.
using I32 = int32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
// Instance list: fp16 input/output with fp32 accumulation.
// NOTE(review): the trailing integer template arguments presumably select
// block size and per-thread tile/vector parameters — confirm their exact
// meaning against device_avgpool3d_bwd_ndhwc_ndhwc.hpp before editing.
using device_avgpool_bwd_ndhwc_f16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
// Instance list: bf16 input/output with fp32 accumulation.
using device_avgpool_bwd_ndhwc_bf16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
// Instance list: fp32 throughout.
using device_avgpool_bwd_ndhwc_f32_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
#include <vector> namespace ck {
namespace tensor_operation {
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" namespace device {
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" namespace instance {
namespace ck { void add_device_avgpool_bwd_ndhwc_bf16_instances(
namespace tensor_operation { std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>& instances)
namespace device { {
namespace instance { add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_bf16_instances{});
}
void add_device_softmax_i8_i8_rank4_reduce1_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>& instances); } // namespace instance
} // namespace device
} // namespace instance } // namespace tensor_operation
} // namespace device } // namespace ck
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
#include <vector> namespace ck {
namespace tensor_operation {
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" namespace device {
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" namespace instance {
namespace ck { void add_device_avgpool_bwd_ndhwc_f16_instances(
namespace tensor_operation { std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>& instances)
namespace device { {
namespace instance { add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f16_instances{});
}
void add_device_softmax_i8_i8_rank3_reduce2_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>& instances); } // namespace instance
} // namespace device
} // namespace instance } // namespace tensor_operation
} // namespace device } // namespace ck
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
#include <vector> namespace ck {
namespace tensor_operation {
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" namespace device {
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" namespace instance {
namespace ck { void add_device_avgpool_bwd_ndhwc_f32_instances(
namespace tensor_operation { std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>& instances)
namespace device { {
namespace instance { add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f32_instances{});
}
void add_device_softmax_i8_i8_rank3_reduce1_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>& instances); } // namespace instance
} // namespace device
} // namespace instance } // namespace tensor_operation
} // namespace device } // namespace ck
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_batched_gemm_instance set(BATCHED_GEMM_INSTANCES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
) device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment