Unverified commit 4fe9c393, authored by Chao Liu, committed by GitHub

N-D Tensor Contraction example, instance, and client example (#270)

* adding contraction

* add contraction example

* update example

* update example

* format

* update readme

* clean header

* clean header

* contraction with multiple D

* rename

* fix naming issue; add instances for contraction+bilinear

* change assumed virtual layout of contraction; add client example

* update example

* update

* contraction+scale

* use type_convert

* rename
parent 334361cb
add_executable(client_contraction_scale contraction_scale.cpp)
target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations)
add_executable(client_contraction_bilinear contraction_bilinear.cpp)
target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Bilinear;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
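// Minimal owner of a device allocation; hipMalloc/hipFree return codes are intentionally ignored.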
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 25)
{
const ck::index_t M0 = std::stoi(argv[1]);
const ck::index_t M1 = std::stoi(argv[2]);
const ck::index_t N0 = std::stoi(argv[3]);
const ck::index_t N1 = std::stoi(argv[4]);
const ck::index_t K0 = std::stoi(argv[5]);
const ck::index_t K1 = std::stoi(argv[6]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])};
alpha = std::stof(argv[23]);
beta = std::stof(argv[24]);
}
else
{
printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg23 to 24: alpha, beta\n");
exit(0);
}
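// Element space spanned by a strided tensor: 1 + sum_i (length_i - 1) * stride_i.
// The strides above need not describe a packed layout.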
auto f_tensor_space_size = [](auto lengths, auto strides) {
std::size_t space_size = 1;
for(std::size_t i = 0; i < lengths.size(); ++i)
{
space_size += (lengths[i] - 1) * strides[i];
}
return space_size;
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) *
f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
SimpleDeviceMem b_device_buf(sizeof(BDataType) *
f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
SimpleDeviceMem d_device_buf(sizeof(DDataType) *
f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides));
SimpleDeviceMem e_device_buf(sizeof(EDataType) *
f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
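// Bilinear combines the contraction result c with d: e = alpha * c + beta * d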
const auto cde_element_op = CDEElementOp{alpha, beta};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
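// StreamConfig{nullptr, true}: run on the default stream with kernel timing enabled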
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
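// ave_time is in ms: flop / 1e9 / ms = TFlops, bytes / 1e6 / ms = GB/s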
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Scale;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float scale = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 20)
{
const ck::index_t M0 = std::stoi(argv[1]);
const ck::index_t M1 = std::stoi(argv[2]);
const ck::index_t N0 = std::stoi(argv[3]);
const ck::index_t N1 = std::stoi(argv[4]);
const ck::index_t K0 = std::stoi(argv[5]);
const ck::index_t K1 = std::stoi(argv[6]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
scale = std::stof(argv[19]);
}
else
{
printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg19: scale\n");
exit(0);
}
auto f_tensor_space_size = [](auto lengths, auto strides) {
std::size_t space_size = 1;
for(std::size_t i = 0; i < lengths.size(); ++i)
{
space_size += (lengths[i] - 1) * strides[i];
}
return space_size;
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) *
f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
SimpleDeviceMem b_device_buf(sizeof(BDataType) *
f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
SimpleDeviceMem e_device_buf(sizeof(EDataType) *
f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
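// Scale multiplies the contraction result by a constant: e = scale * c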
const auto cde_element_op = CDEElementOp{scale};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
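// scale has no D tensors, so the Ds pointer and descriptor arrays are empty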
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
@@ -9,3 +9,4 @@ message(STATUS "Build with HIP ${hip_VERSION}")
 add_subdirectory(01_gemm)
 add_subdirectory(02_gemm_add_add_fastgelu)
 add_subdirectory(03_gemm_layernorm)
+add_subdirectory(04_contraction)
@@ -213,15 +213,15 @@ int main(int argc, char* argv[])
 d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
 }
-DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
-DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
+DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
-a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-b_k_n_device_buf.ToDevice(b_k_n.mData.data());
-d_m_n_device_buf.ToDevice(d_m_n.mData.data());
-e_m_n_device_buf.ToDevice(e_m_n_device_result.mData.data());
+a_device_buf.ToDevice(a_m_k.mData.data());
+b_device_buf.ToDevice(b_k_n.mData.data());
+d_device_buf.ToDevice(d_m_n.mData.data());
+e_device_buf.ToDevice(e_m_n_device_result.mData.data());
 auto a_element_op = AElementOp{};
 auto b_element_op = BElementOp{};
@@ -231,10 +231,10 @@ int main(int argc, char* argv[])
 auto device_op = DeviceOpInstance{};
 auto invoker = device_op.MakeInvoker();
 auto argument =
-device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
-b_k_n_device_buf.GetDeviceBuffer(),
-std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
-e_m_n_device_buf.GetDeviceBuffer(),
+device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+b_device_buf.GetDeviceBuffer(),
+std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+e_device_buf.GetDeviceBuffer(),
 M,
 N,
 K,
@@ -266,7 +266,7 @@ int main(int argc, char* argv[])
 std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
 << std::endl;
-e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 if(do_verification)
 {
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
 }
 }
-e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
 }
@@ -191,14 +191,14 @@ int main(int argc, char* argv[])
 d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
 }
-DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
-DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
+DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
-a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-b_k_n_device_buf.ToDevice(b_k_n.mData.data());
-d_m_n_device_buf.ToDevice(d_m_n.mData.data());
+a_device_buf.ToDevice(a_m_k.mData.data());
+b_device_buf.ToDevice(b_k_n.mData.data());
+d_device_buf.ToDevice(d_m_n.mData.data());
 auto a_element_op = AElementOp{};
 auto b_element_op = BElementOp{};
@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
 auto invoker = device_op.MakeInvoker();
 auto argument =
-device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
-b_k_n_device_buf.GetDeviceBuffer(),
-std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
-e_m_n_device_buf.GetDeviceBuffer(),
+device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+b_device_buf.GetDeviceBuffer(),
+std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+e_device_buf.GetDeviceBuffer(),
 M,
 N,
 K,
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
 if(do_verification)
 {
-e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -156,16 +156,16 @@ int main(int argc, char* argv[])
 d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
 }
-DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
-DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
-DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
+DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
+DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
-a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-b_k_n_device_buf.ToDevice(b_k_n.mData.data());
-d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
-d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
+a_device_buf.ToDevice(a_m_k.mData.data());
+b_device_buf.ToDevice(b_k_n.mData.data());
+d0_device_buf.ToDevice(d0_m_n.mData.data());
+d1_device_buf.ToDevice(d1_m_n.mData.data());
 auto a_element_op = AElementOp{};
 auto b_element_op = BElementOp{};
@@ -175,11 +175,11 @@ int main(int argc, char* argv[])
 auto device_op = DeviceOpInstance{};
 auto invoker = device_op.MakeInvoker();
 auto argument =
-device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
-b_k_n_device_buf.GetDeviceBuffer(),
-std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
-d1_m_n_device_buf.GetDeviceBuffer()},
-e_m_n_device_buf.GetDeviceBuffer(),
+device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+b_device_buf.GetDeviceBuffer(),
+std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
+d1_device_buf.GetDeviceBuffer()},
+e_device_buf.GetDeviceBuffer(),
 M,
 N,
 K,
@@ -239,7 +239,7 @@ int main(int argc, char* argv[])
 }
 }
-e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
 }
@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
 for(int m = 0; m < M; ++m)
 for(int n = 0; n < N; ++n)
 {
-AccDataType acc =
-static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
-AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));
+AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) +
+ck::type_convert<AccDataType>(bias_n(n));
+AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
 c_element_op(acc, acc);
 c1_element_op(c1, c1);
 acc += c1;
-c_m_n(m, n) = static_cast<CDataType>(acc);
+c_m_n(m, n) = ck::type_convert<CDataType>(acc);
 }
 // reduce_mean and reduce_square_mean
@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
 {
 AccDataType out_acc = 0;
 layerNormInst(out_acc,
-static_cast<AccDataType>(c_m_n(m, n)),
-static_cast<AccDataType>(mean_m(m)),
-static_cast<AccDataType>(meanSquare_m(m)),
-static_cast<AccDataType>(gamma_n(n)),
-static_cast<AccDataType>(beta_n(n)));
+ck::type_convert<AccDataType>(c_m_n(m, n)),
+ck::type_convert<AccDataType>(mean_m(m)),
+ck::type_convert<AccDataType>(meanSquare_m(m)),
+ck::type_convert<AccDataType>(gamma_n(n)),
+ck::type_convert<AccDataType>(beta_n(n)));
-out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
+out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
 }
 }
 }
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
# Instructions for ```example_contraction_bilinear_xdl_fp32```
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```
Result (MI100 @ dynamic freq, 46 TFlops peak FP32)
```
a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
```
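
For reference, a sketch of the math this example performs, with `alpha`/`beta` the example's scaling factors (both default to 1):

```
E[m0, m1, n0, n1] = alpha * sum_{k0, k1} A[m0, m1, k0, k1] * B[n0, n1, k0, k1]
                  + beta * D[m0, m1, n0, n1]
```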
@@ -44,3 +44,4 @@ add_subdirectory(22_cgemm)
 add_subdirectory(23_softmax)
 add_subdirectory(24_batched_gemm_c_permute)
 add_subdirectory(25_gemm_bias_c_permute)
+add_subdirectory(26_contraction)
@@ -102,7 +102,12 @@
 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
 // experimental feature: buffer load/store/atomic-add OOB trick
+// This (ifndef) is a hack to use customized behavior for buffer load rather than the default
+// setting. Don't use this hack unless absolutely necessary!
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
+#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
+#endif
 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
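// e.g. with NumDimM = NumDimN = NumDimK = 2:
// E[m0, m1, n0, n1] = cde_op(sum_{k0, k1} a_op(A[m0, m1, k0, k1]) * b_op(B[n0, n1, k0, k1]),
//                            D0[m0, m1, n0, n1], ...)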
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -2,10 +2,11 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include <iostream>
 #include <vector>
-#include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -11,11 +11,14 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-// input : A[M, K], B[K, N],
-// input : D0[M, N], D1[M, N], ...
-// output : E[M, N]
-// C = a_op(A) * b_op(B)
-// E = cde_op(C, D0, D1, ...)
+// GEMM:
+// input : A[M, K], B[K, N],
+// input : D0[M, N], D1[M, N], ...
+// output : E[M, N]
+// C = a_op(A) * b_op(B)
+// E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
 template <typename ALayout,
 typename BLayout,
 typename DELayout,
@@ -88,12 +88,15 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-// input : A[M, K], or A[K, N]
-// input : B[K, N], or A[N, K]
-// input : D0[M, N], D1[M, N], ...
-// output : E[M, N]
-// C = a_op(A) * b_op(B)
-// E = cde_op(C, D0, D1, ...)
+// GEMM:
+// input : A[AK0, M, AK1]
+// input : B[AK0, N, AK1]
+// input : D0[M, N], D1[M, N], ...
+// output : E[M, N]
+// C = a_op(A) * b_op(B)
+// E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
 template <typename ALayout,
 typename BLayout,
 typename DELayout,
@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 }
 }
-static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
+static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
 {
 const auto c_grid_desc_mraw_nraw = [&]() {
 if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
 using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
 // GridwiseGemm
 using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
 b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
 ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
-e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
+e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
 e_grid_desc_mblock_mperblock_nblock_nperblock_{},
 block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
 a_element_op_{a_element_op},
@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
 const auto d_grid_desc_m_n =
-DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
+DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
 ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 }
 }
-// ck::Tuple<const DsDataType*...>
-static constexpr auto MakeDsGridPointer()
-{
-return generate_tuple(
-[&](auto i) {
-using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
-return static_cast<const DDataType*>(nullptr);
-},
-Number<NumDTensor>{});
-}
 // private:
+// pointers
 const ADataType* p_a_grid_;
 const BDataType* p_b_grid_;
 typename GridwiseGemm::DsGridPointer p_ds_grid_;
 EDataType* p_e_grid_;
+// tensor descriptors
 AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
 StaticallyIndexedArray<
@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 EGridDesc_M_N e_grid_desc_m_n_;
 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
 e_grid_desc_mblock_mperblock_nblock_nperblock_;
+// block-to-e-tile map
 typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
+// element-wise op
 AElementwiseOperation a_element_op_;
 BElementwiseOperation b_element_op_;
 CDEElementwiseOperation cde_element_op_;
@@ -24,9 +24,25 @@ struct PassThrough
 };
 };
+struct Scale
+{
+__host__ __device__ Scale(float scale) : scale_(scale) {}
+
+template <typename Y, typename X>
+__host__ __device__ void operator()(Y& y, const X& x) const;
+
+template <>
+__host__ __device__ void operator()<float, float>(float& y, const float& x) const
+{
+y = scale_ * x;
+};
+
+float scale_;
+};
+
 struct UnaryDivide
 {
-__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){};
+__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
 template <typename T>
 __host__ __device__ void operator()(T& y, const T& x) const