Commit 05fc2f8e authored by ltqin's avatar ltqin
Browse files

add source code interface to factory

parent cf33526e
...@@ -138,7 +138,7 @@ int main(int argc, char* argv[]) ...@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
std::array<std::vector<ck::index_t>, 2>{ std::array<std::vector<ck::index_t>, 2>{
d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 2>{ std::array<std::vector<ck::index_t>, 2>{
d01_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d00_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides {}, // acc1_biases_gs_ms_os_strides
AElementOp{}, AElementOp{},
...@@ -210,7 +210,7 @@ int main(int argc, char* argv[]) ...@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
std::array<std::vector<ck::index_t>, 2>{ std::array<std::vector<ck::index_t>, 2>{
d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d00_gs_ms_ns_lengths, d01_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 2>{ std::array<std::vector<ck::index_t>, 2>{
d01_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d00_gs_ms_ns_strides, d01_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides {}, // acc1_biases_gs_ms_os_strides
AElementOp{}, AElementOp{},
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
#include <vector> #include <vector>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough; using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleMask;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough; using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -23,6 +23,7 @@ using ADataType = ck::half_t; ...@@ -23,6 +23,7 @@ using ADataType = ck::half_t;
using B0DataType = ck::half_t; using B0DataType = ck::half_t;
using B1DataType = ck::half_t; using B1DataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using D00DataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
struct SimpleDeviceMem struct SimpleDeviceMem
...@@ -41,7 +42,7 @@ struct SimpleDeviceMem ...@@ -41,7 +42,7 @@ struct SimpleDeviceMem
void* p_mem_; void* p_mem_;
}; };
int main() int main(int argc, char* argv[])
{ {
int G0 = 48; int G0 = 48;
int G1 = 16; int G1 = 16;
...@@ -66,8 +67,13 @@ int main() ...@@ -66,8 +67,13 @@ int main()
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D00 layout [G0, M, G1, N]
std::vector<ck::index_t> d00_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d00_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K); SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K); SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
SimpleDeviceMem d00_device_buf(sizeof(D00DataType) * G0 * G1 * M * N);
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N); SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O); SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
...@@ -81,7 +87,7 @@ int main() ...@@ -81,7 +87,7 @@ int main()
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ck::Tuple<>, ck::Tuple<D00DataType>,
ck::Tuple<>, ck::Tuple<>,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
...@@ -89,11 +95,10 @@ int main() ...@@ -89,11 +95,10 @@ int main()
B1ElementOp, B1ElementOp,
CElementOp, CElementOp,
MaskingSpec>; MaskingSpec>;
// get device op instances // get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
ck::tensor_operation::device::instance:: DeviceOp>::GetInstances();
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
...@@ -106,14 +111,15 @@ int main() ...@@ -106,14 +111,15 @@ int main()
// profile device op instances // profile device op instances
std::cout << "Run all instances and do timing" << std::endl; std::cout << "Run all instances and do timing" << std::endl;
for(size_t i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(), b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(), c_device_buf.GetDeviceBuffer(),
{}, // p_acc0_biases std::array<void*, 1>{d00_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases {}, // p_acc1_biases
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
...@@ -123,13 +129,15 @@ int main() ...@@ -123,13 +129,15 @@ int main()
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
{}, // acc0_biases_gs_ms_ns_lengths std::array<std::vector<ck::index_t>, 1>{
{}, // acc0_biases_gs_ms_ns_strides d00_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d00_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides {}, // acc1_biases_gs_ms_os_strides
AElementOp{}, AElementOp{},
B0ElementOp{}, B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)}, Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{}, B1ElementOp{},
CElementOp{}); CElementOp{});
...@@ -143,7 +151,8 @@ int main() ...@@ -143,7 +151,8 @@ int main()
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D00DataType) * M * N) *
G0 * G1; G0 * G1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -176,11 +185,12 @@ int main() ...@@ -176,11 +185,12 @@ int main()
auto& op_ptr = op_ptrs[best_op_id]; auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(), b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(), c_device_buf.GetDeviceBuffer(),
{}, // p_acc0_biases std::array<void*, 1>{d00_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases {}, // p_acc1_biases
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
...@@ -190,13 +200,15 @@ int main() ...@@ -190,13 +200,15 @@ int main()
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
{}, // acc0_biases_gs_ms_ns_lengths std::array<std::vector<ck::index_t>, 1>{
{}, // acc0_biases_gs_ms_ns_strides d00_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d00_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides {}, // acc1_biases_gs_ms_os_strides
AElementOp{}, AElementOp{},
B0ElementOp{}, B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)}, Acc0ElementOp{1 / sqrtf(K), 0.1},
B1ElementOp{}, B1ElementOp{},
CElementOp{}); CElementOp{});
......
...@@ -83,7 +83,10 @@ template <index_t NumDimG, ...@@ -83,7 +83,10 @@ template <index_t NumDimG,
typename C0DEElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value ||
is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value,
bool>::type = false>
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
......
...@@ -11,52 +11,13 @@ ...@@ -11,52 +11,13 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
element_wise::ScaleMask,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2, DeviceBatchedGemmSoftmaxGemmPermute<2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment