Commit 6c2b74de authored by Chao Liu

change assumed virtual layout of contraction; add client example
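In short: the assumed virtual layout of B changes from B[K0, K1, ..., N0, N1, ...] to B[N0, N1, ..., K0, K1, ...] (A keeps [M..., K...] and D/E keep [M..., N...]), and a standalone client example is added that enumerates all contraction+bilinear device instances and reports the best-performing one.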

parent f6c42793
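# assumes the composable_kernel package has already been found (e.g. via find_package),
# so that the imported target composable_kernel::device_operations is available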
add_executable(client_contraction_bilinear contraction_bilinear.cpp)
target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Bilinear;
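// Bilinear computes E = alpha * C + beta * D, where C is the contraction of A and B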
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
struct SimpleDeviceMem
{
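// minimal RAII wrapper around a raw device allocation for this example;
// hipMalloc/hipFree return codes are intentionally discarded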
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 25)
{
const ck::index_t M0 = std::stoi(argv[1]);
const ck::index_t M1 = std::stoi(argv[2]);
const ck::index_t N0 = std::stoi(argv[3]);
const ck::index_t N1 = std::stoi(argv[4]);
const ck::index_t K0 = std::stoi(argv[5]);
const ck::index_t K1 = std::stoi(argv[6]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])};
alpha = std::stof(argv[23]);
beta = std::stof(argv[24]);
}
else
{
printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg23 to 24: alpha, beta\n");
exit(0);
}
auto f_tensor_space_size = [](auto lengths, auto strides) {
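// element space = 1 + sum_i (length_i - 1) * stride_i: one past the highest
// addressable element, which stays correct for non-packed (padded) strides
// like the defaults above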
std::size_t space_size = 1;
for(std::size_t i = 0; i < lengths.size(); ++i)
{
space_size += (lengths[i] - 1) * strides[i];
}
return space_size;
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) *
f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
SimpleDeviceMem b_device_buf(sizeof(BDataType) *
f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
SimpleDeviceMem d_device_buf(sizeof(DDataType) *
f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides));
SimpleDeviceMem e_device_buf(sizeof(EDataType) *
f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{alpha, beta};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
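// ave_time is in ms, so flop / 1e9 / ms = TFLOP/s and bytes / 1e6 / ms = GB/s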
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
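Aside: the command line takes explicit strides, so non-packed layouts are expressible. For a packed row-major layout the strides follow directly from the lengths; a minimal helper sketch (hypothetical, not part of this commit):
// packed (row-major) strides: stride[last] = 1, stride[i] = stride[i+1] * length[i+1]
std::vector<ck::index_t> packed_strides(const std::vector<ck::index_t>& lengths)
{
    std::vector<ck::index_t> strides(lengths.size(), 1);
    for(int i = static_cast<int>(lengths.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lengths[i + 1];
    return strides;
}
// e.g. packed_strides({30, 128, 32, 64}) yields {262144, 2048, 64, 1}; the defaults
// above ({524288, 4096, 128, 1}) describe a padded, non-packed layout instead.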
......@@ -9,3 +9,4 @@ message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(01_gemm)
add_subdirectory(02_gemm_add_add_fastgelu)
add_subdirectory(03_gemm_layernorm)
add_subdirectory(05_contraction_bilinear)
......@@ -213,15 +213,15 @@ int main(int argc, char* argv[])
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m_n_device_buf.ToDevice(d_m_n.mData.data());
e_m_n_device_buf.ToDevice(e_m_n_device_result.mData.data());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -231,10 +231,10 @@ int main(int argc, char* argv[])
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
e_m_n_device_buf.GetDeviceBuffer(),
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
......@@ -266,7 +266,7 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
......@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
}
}
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
}
......
......@@ -191,14 +191,14 @@ int main(int argc, char* argv[])
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m_n_device_buf.ToDevice(d_m_n.mData.data());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
e_m_n_device_buf.GetDeviceBuffer(),
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
......@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
......
......@@ -156,16 +156,16 @@ int main(int argc, char* argv[])
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -175,11 +175,11 @@ int main(int argc, char* argv[])
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
d1_m_n_device_buf.GetDeviceBuffer()},
e_m_n_device_buf.GetDeviceBuffer(),
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
......@@ -239,7 +239,7 @@ int main(int argc, char* argv[])
}
}
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
}
......
......@@ -69,13 +69,13 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ks_ns_{b_ks_ns},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
......@@ -84,7 +84,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ks_ns_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
......@@ -115,7 +115,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
arg.a_element_op_(
v_a, static_cast<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, static_cast<const AccDataType>(arg.b_ks_ns_(k0, k1, n0, n1)));
v_b, static_cast<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
......@@ -157,13 +157,13 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ks_ns,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ks_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -192,42 +192,86 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = false;
if(argc == 4)
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 28)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
alpha = std::stof(argv[26]);
beta = std::stof(argv[27]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg26 to 27: alpha, beta\n");
exit(0);
}
const float alpha = 1;
const float beta = 1;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[K0, K1, N0, N1]
std::vector<ck::index_t> b_ks_ns_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ks_ns_strides{128, 1, 524288, 4096};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ks_ns(
std::vector<std::size_t>(b_ks_ns_lengths.begin(), b_ks_ns_lengths.end()),
std::vector<std::size_t>(b_ks_ns_strides.begin(), b_ks_ns_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
......@@ -239,7 +283,7 @@ int main(int argc, char* argv[])
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ks_ns: " << b_ks_ns.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
......@@ -248,55 +292,50 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ks_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ks_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b_ks_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_ms_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_ms_ks_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_ks_ns_device_buf(sizeof(BDataType) * b_ks_ns.mDesc.GetElementSpace());
DeviceMem d_ms_ns_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_ms_ns_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
a_ms_ks_device_buf.ToDevice(a_ms_ks.mData.data());
b_ks_ns_device_buf.ToDevice(b_ks_ns.mData.data());
d_ms_ns_device_buf.ToDevice(d_ms_ns.mData.data());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
e_ms_ns_device_buf.SetZero();
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument =
op.MakeArgument(a_ms_ks_device_buf.GetDeviceBuffer(),
b_ks_ns_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_ms_ns_device_buf.GetDeviceBuffer()},
e_ms_ns_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ks_ns_lengths,
b_ks_ns_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
......@@ -333,7 +372,7 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_ms_ns_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
if(do_verification)
{
......@@ -356,7 +395,7 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ks_ns, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
......
......@@ -20,10 +20,10 @@ namespace device {
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
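// e.g. for NumDimM = NumDimN = NumDimK = 2 with cde_op = Bilinear (a sketch):
//   E[m0,m1,n0,n1] = alpha * sum_{k0,k1} A[m0,m1,k0,k1] * B[n0,n1,k0,k1]
//                  + beta * D[m0,m1,n0,n1]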
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
......@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
......
......@@ -97,10 +97,10 @@ namespace device {
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[K0, K1, K2, ..., N0, N1, N2...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
......@@ -164,19 +164,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// assume A[M0, M1, M2, ..., K0, K1, K2...]
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_lengths_vec,
const std::vector<index_t>& a_strides_vec)
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec,
const std::vector<index_t>& a_ms_ks_strides_vec)
{
assert(a_lengths_vec.size() == NumDimM + NumDimK &&
a_strides_vec.size() == NumDimM + NumDimK);
assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
const auto to_tuple = [&](auto& vec, auto Num) {
return generate_tuple([&](auto i) { return vec[i]; }, Num);
};
const auto a_lengths = to_tuple(a_lengths_vec, Number<NumDimM + NumDimK>{});
const auto a_strides = to_tuple(a_strides_vec, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
......@@ -186,13 +186,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_lengths, mDimIds);
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_lengths, kDimIds);
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks = make_naive_tensor_descriptor(a_lengths, a_strides);
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
......@@ -292,42 +293,43 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
// assume B[K0, K1, K2, ..., N0, N1, N2...]
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_lengths_vec,
const std::vector<index_t>& b_strides_vec)
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_ns_ks_lengths_vec,
const std::vector<index_t>& b_ns_ks_strides_vec)
{
assert(b_lengths_vec.size() == NumDimN + NumDimK &&
b_strides_vec.size() == NumDimN + NumDimK);
assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
const auto to_tuple = [&](auto& vec, auto Num) {
return generate_tuple([&](auto i) { return vec[i]; }, Num);
};
const auto b_lengths = to_tuple(b_lengths_vec, Number<NumDimN + NumDimK>{});
const auto b_strides = to_tuple(b_strides_vec, Number<NumDimN + NumDimK>{});
// dimension Ids for K0, K1, ...
constexpr auto kDimIds = typename arithmetic_sequence_gen<0, NumDimK, 1>::type{};
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimK, NumDimK + NumDimN, 1>::type{};
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
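// e.g. for NumDimN = NumDimK = 2: nDimIds = Sequence<0, 1>, kDimIds = Sequence<2, 3>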
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(b_lengths, kDimIds);
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_lengths, nDimIds);
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
// naive tensor B[K0, K1, K2..., N0, N1, N2, ...]
const auto b_grid_desc_ks_ns = make_naive_tensor_descriptor(b_lengths, b_strides);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[KRaw = K0 * K1 * K2 * ... , NRaw = N0 * N1 * N2 * ... ]
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ks_ns,
make_tuple(make_merge_transform(kLengths), make_merge_transform(nLengths)),
make_tuple(kDimIds, nDimIds),
make_tuple(Sequence<1>{}, Sequence<0>{}));
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
......@@ -421,18 +423,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_lengths_vec,
const std::vector<index_t>& e_strides_vec)
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
const std::vector<index_t>& e_ms_ns_strides_vec)
{
assert(e_lengths_vec.size() == NumDimM + NumDimN &&
e_strides_vec.size() == NumDimM + NumDimN);
assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto Num) {
return generate_tuple([&](auto i) { return vec[i]; }, Num);
};
const auto e_lengths = to_tuple(e_lengths_vec, Number<NumDimM + NumDimN>{});
const auto e_strides = to_tuple(e_strides_vec, Number<NumDimM + NumDimN>{});
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
......@@ -442,13 +444,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_lengths, mDimIds);
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(e_lengths, nDimIds);
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns = make_naive_tensor_descriptor(e_lengths, e_strides);
const auto e_grid_desc_ms_ns =
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
......@@ -564,14 +567,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
......@@ -579,10 +582,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_lengths, a_strides)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_lengths, b_strides)},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ks_lengths, a_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_lengths, e_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
......@@ -616,7 +621,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N(ds_lengths[i], ds_strides[i]);
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -624,24 +629,33 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
});
}
// for sanity check of vector load/store
a_mz_length_ = a_lengths[NumDimM - 1];
a_mz_stride_ = a_strides[NumDimM - 1];
// for sanity check of vector memory access
a_mz_length_ = a_ms_ks_lengths[NumDimM - 1];
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
a_kz_length_ = a_ms_ks_lengths[NumDimM + NumDimK - 1];
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
a_kz_length_ = a_lengths[NumDimM + NumDimK - 1];
a_kz_stride_ = a_strides[NumDimM + NumDimK - 1];
b_nz_length_ = b_ns_ks_lengths[NumDimN - 1];
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
b_nz_length_ = b_lengths[NumDimK - 1];
b_nz_stride_ = b_strides[NumDimK - 1];
b_kz_length_ = b_ns_ks_lengths[NumDimN + NumDimK - 1];
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
b_kz_length_ = b_lengths[NumDimK + NumDimN - 1];
b_kz_stride_ = b_strides[NumDimK + NumDimN - 1];
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_mz_length_[i] = ds_ms_ns_lengths[i][NumDimM - 1];
ds_mz_stride_[i] = ds_ms_ns_strides[i][NumDimM - 1];
ds_nz_length_[i] = ds_ms_ns_lengths[i][NumDimM + NumDimN - 1];
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
}
e_mz_length_ = b_lengths[NumDimM - 1];
e_mz_stride_ = b_strides[NumDimM - 1];
e_mz_length_ = e_ms_ns_lengths[NumDimM - 1];
e_mz_stride_ = e_ms_ns_strides[NumDimM - 1];
e_nz_length_ = b_lengths[NumDimM + NumDimN - 1];
e_nz_stride_ = b_strides[NumDimM + NumDimN - 1];
e_nz_length_ = e_ms_ns_lengths[NumDimM + NumDimN - 1];
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
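// the innermost length/stride of each tensor's dimension groups is recorded above,
// presumably so the supported-argument check can verify that vectorized access
// happens along a dimension with the expected stride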
}
// private:
......@@ -682,6 +696,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
index_t b_nz_stride_;
index_t b_kz_length_;
index_t b_kz_stride_;
std::array<index_t, NumDTensor> ds_mz_length_;
std::array<index_t, NumDTensor> ds_mz_stride_;
std::array<index_t, NumDTensor> ds_nz_length_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_mz_length_;
index_t e_mz_stride_;
index_t e_nz_length_;
......@@ -823,14 +841,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
......@@ -839,14 +857,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b,
p_ds,
p_e,
a_lengths,
a_strides,
b_lengths,
b_strides,
ds_lengths,
ds_strides,
e_lengths,
e_strides,
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op};
......@@ -860,14 +878,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_lengths,
std::vector<index_t> b_strides,
std::array<std::vector<index_t>, NumDTensor> ds_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_strides,
std::vector<index_t> e_lengths,
std::vector<index_t> e_strides,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
......@@ -876,14 +894,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b,
p_ds,
p_e,
a_lengths,
a_strides,
b_lengths,
b_strides,
ds_lengths,
ds_strides,
e_lengths,
e_strides,
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
......
......@@ -19,6 +19,8 @@ using BF16 = ck::bhalf_t;
using F16_TUPLE = ck::Tuple<F16>;
using F16_F16_TUPLE = ck::Tuple<F16, F16>;
using F32_TUPLE = ck::Tuple<F32>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F32,
F32,
F32_TUPLE,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F32,
F32,
F32_TUPLE,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F32,
F32,
F32_TUPLE,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F32,
F32,
F32_TUPLE,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
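// Naming: the suffix encodes the fast-changing dimension of A/B/D/E in order,
// e.g. "kknn" = K fastest for A and B, N fastest for D and E (see the comments
// in the instance definition files below)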
// Contraction + Bilinear
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContractionMultipleD<
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>>
{
using DeviceOp = DeviceContractionMultipleD<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<DDataType, float> && is_same_v<EDataType, float>)
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
// n0, n1] k/k/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
......
......@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
// n0, n1] k/n/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
......
......@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
// n0, n1] m/k/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
......
......@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
// n0, n1] m/n/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
......