Commit 8d4b916e authored by ltqin's avatar ltqin
Browse files

bias mask out of iner op

parent a483948c
......@@ -4,8 +4,8 @@ target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_o
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_bias_mask fused_attention_bias_mask.cpp)
target_link_libraries(client_fused_attention_bias_mask PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_bias_mask_no_lib fused_attention_bias_mask_no_lib.cpp)
target_link_libraries(client_fused_attention_bias_mask_no_lib PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_no_lib fused_attention_no_lib.cpp)
target_link_libraries(client_fused_attention_no_lib PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_mask_no_lib fused_attention_mask_no_lib.cpp)
target_link_libraries(client_fused_attention_mask_no_lib PRIVATE composable_kernel::device_operations)
......@@ -5,14 +5,46 @@
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
struct ScaleBiasMask
{
ScaleBiasMask(float scale, float mask_filter_value)
: scale_(scale), mask_filter_value_(mask_filter_value)
{
}
// biased, masked
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void
operator()(Y& y, const X0& x, const X1& bias, const X2& mask) const;
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const ck::half_t& bias, const int16_t& mask) const
{
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const ck::half_t& bias, const ck::half_t& mask) const
{
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
const float scale_;
const float mask_filter_value_;
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleBiasMask;
using Acc0ElementOp = ScaleBiasMask;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
......
......@@ -5,7 +5,7 @@
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......
......@@ -389,37 +389,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
}
};
struct ScaleBiasMask
{
ScaleBiasMask(float scale, float mask_filter_value)
: scale_(scale), mask_filter_value_(mask_filter_value)
{
}
// biased, masked
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void
operator()(Y& y, const X0& x, const X1& bias, const X2& mask) const;
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const int16_t& mask) const
{
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const half_t& mask) const
{
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
const float scale_;
const float mask_filter_value_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -11,13 +11,99 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -38,7 +124,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -59,7 +145,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -80,7 +166,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -100,85 +186,67 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,6 +3,5 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask;
// f16 ScaleBiasMask masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
F16,
F32,
ck::Tuple<F16, F16>,
ScaleBiasMask,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
// f16 ScaleBiasMask disable masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
F16,
F32,
ck::Tuple<F16, F16>,
ScaleBiasMask,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......
......@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment