Commit 7402fcbe authored by ltqin's avatar ltqin
Browse files

add client example for gemm_bias_gemm

parent 22e7a408
......@@ -3,3 +3,9 @@ target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_o
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_mask fused_attention_mask.cpp)
target_link_libraries(client_fused_attention_mask PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_bias_mask fused_attention_bias_mask.cpp)
target_link_libraries(client_fused_attention_bias_mask PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleBiasMask;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
constexpr static auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
using ADataType = ck::half_t;
using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using CDataType = ck::half_t;
using D00DataType = ck::half_t;
using D01DataType = int32_t;
using AccDataType = float;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
int G0 = 48;
int G1 = 16;
int M = 1024;
int N = 1024;
int K = 64;
int O = 64;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D00 layout [G0, M, G1, N]
std::vector<ck::index_t> d00_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d00_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
// D01 layout [G0, M, G1, N]
std::vector<ck::index_t> d01_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d01_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
SimpleDeviceMem d00_device_buf(sizeof(D00DataType) * G0 * G1 * M * N);
SimpleDeviceMem d01_device_buf(sizeof(D01DataType) * G0 * G1 * M * N);
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D00DataType, D01DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 2>{d00_device_buf.GetDeviceBuffer(),
d01_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 2>{
d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 2>{
d01_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D00DataType) * M * N * 2) *
G0 * G1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 2>{d00_device_buf.GetDeviceBuffer(),
d01_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 2>{
d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 2>{
d01_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleMask;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
constexpr static auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
using ADataType = ck::half_t;
using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using CDataType = ck::half_t;
using D0DataType = int32_t;
using AccDataType = float;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
int G0 = 48;
int G1 = 16;
int M = 1024;
int N = 1024;
int K = 64;
int O = 64;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
using DeviceOp =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D0DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D0DataType) * M * N) *
G0 * G1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,6 +3,6 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
)
......@@ -24,8 +24,8 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Scale = ck::tensor_operation::element_wise::Scale;
using ScaleMask = ck::tensor_operation::element_wise::ScaleMask;
using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
......@@ -68,8 +68,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
>;
// f16 PassThrough masking
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
// f16 ScaleMask masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -80,11 +80,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
PassThrough,
ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
......@@ -100,13 +100,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
1,
F16,
F32,
ck::Tuple<>,
PassThrough,
ck::Tuple<int32_t>,
ScaleMask,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
// f16 PassThrough disable masking
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
// f16 ScaleMask disable masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -117,11 +117,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
PassThrough,
ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
......@@ -137,13 +137,13 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
1,
F16,
F32,
ck::Tuple<>,
PassThrough,
ck::Tuple<int32_t>,
ScaleMask,
MaskingSpecialization::MaskDisabled>{});
}
// f16
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
// f16 ScaleBiasMask masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -154,11 +154,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<F16, int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
......@@ -174,12 +174,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
1,
F16,
F32,
ck::Tuple<>,
Scale,
ck::Tuple<F16, int32_t>,
ScaleBiasMask,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
// f16 ScaleBiasMask disable masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
......@@ -190,11 +191,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<F16, int32_t>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
ScaleBiasMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
......@@ -210,8 +211,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
1,
F16,
F32,
ck::Tuple<>,
Scale,
ck::Tuple<F16, int32_t>,
ScaleBiasMask,
MaskingSpecialization::MaskDisabled>{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment