Commit fa9da1a4 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 4c105089 457308e3
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -9,34 +9,33 @@ ...@@ -9,34 +9,33 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_instances( template <typename InDataType,
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&); typename AccDataType,
void add_device_softmax_f16_f16_rank4_instances( typename OutDataType,
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&); index_t Rank,
index_t NumReduceDim>
void add_device_softmax_f32_f32_rank3_instances( struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&); AccDataType,
void add_device_softmax_f32_f32_rank4_instances( OutDataType,
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&); PassThrough,
PassThrough,
void add_device_softmax_i8_i8_rank3_instances( Rank,
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>&); NumReduceDim>>
void add_device_softmax_i8_i8_rank4_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>&);
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>>
{ {
using DeviceOp = using DeviceOp = DeviceSoftmax<InDataType,
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>; AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory< ...@@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory<
std::is_same_v<OutDataType, F16>) std::is_same_v<OutDataType, F16>)
{ {
if constexpr(Rank == 3) if constexpr(Rank == 3)
add_device_softmax_f16_f16_rank3_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs);
}
else if constexpr(Rank == 4) else if constexpr(Rank == 4)
add_device_softmax_f16_f16_rank4_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
}
} }
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> && else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>) std::is_same_v<OutDataType, F32>)
{ {
if constexpr(Rank == 3) if constexpr(Rank == 3)
add_device_softmax_f32_f32_rank3_instances(op_ptrs); {
else if constexpr(Rank == 4) if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank4_instances(op_ptrs); add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs);
} else if constexpr(NumReduceDim == 2)
else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> && add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs);
std::is_same_v<OutDataType, I8>) else if constexpr(NumReduceDim == 3)
{ add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs);
if constexpr(Rank == 3) }
add_device_softmax_i8_i8_rank3_instances(op_ptrs);
else if constexpr(Rank == 4) else if constexpr(Rank == 4)
add_device_softmax_i8_i8_rank4_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
}
} }
return op_ptrs; return op_ptrs;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce1_instances( void add_device_softmax_f16_f16_rank3_reduce1_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce2_instances( void add_device_softmax_f16_f16_rank3_reduce2_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce3_instances( void add_device_softmax_f16_f16_rank3_reduce3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank4_reduce1_instances( void add_device_softmax_f16_f16_rank4_reduce1_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank4_reduce2_instances( void add_device_softmax_f16_f16_rank4_reduce2_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank4_reduce3_instances( void add_device_softmax_f16_f16_rank4_reduce3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank4_reduce4_instances( void add_device_softmax_f16_f16_rank4_reduce4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple> #include <tuple>
...@@ -16,7 +16,6 @@ template <index_t Rank, index_t Reduce> ...@@ -16,7 +16,6 @@ template <index_t Rank, index_t Reduce>
using device_softmax_f16_f16_instances = std::tuple< using device_softmax_f16_f16_instances = std::tuple<
// clang-format off // clang-format off
// InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
// fallback kernel
DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>,
DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>,
...@@ -33,6 +32,13 @@ using device_softmax_f16_f16_instances = std::tuple< ...@@ -33,6 +32,13 @@ using device_softmax_f16_f16_instances = std::tuple<
// clang-format on // clang-format on
>; >;
template <index_t Rank, index_t Reduce>
using device_softmax_f16_f16_generic_instance = std::tuple<
// clang-format off
DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1>
// clang-format on
>;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_softmax_f32_f32_rank3_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>& instances);
void add_device_softmax_f32_f32_rank4_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment