Commit 6d9a07d7 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents b30d416c a776978c
......@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
y_dev.FromDevice(y.mData.data());
bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3);
if(do_log)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
#if DEBUG_LOG
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
#endif // DEBUG_LOG
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_device_buf[i]->GetDeviceBuffer(),
b_device_buf[i]->GetDeviceBuffer(),
{},
c_device_buf[i]->GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideCs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
DeviceMem grouped_gemm_kernel_args_dev(
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
{
kbatch_list = {kbatch};
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr);
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
namespace ck {
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b,
float scale)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
std::size_t H = A_nchw.mDesc.GetLengths()[2];
std::size_t W = A_nchw.mDesc.GetLengths()[3];
for(std::size_t w = 0; w < W; ++w)
for(std::size_t h = 0; h < H; ++h)
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
using tmp_type = ck::remove_reference_t<decltype(B_nhwc(0, 0))>;
tmp_type tmp_val = 0;
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor_b(tmp_val, a_val);
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
scale * tmp_val);
}
}
template <typename ADataType, typename BDataType, index_t NumDim>
bool test_permute_scale_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> lengths)
{
bool pass = true;
using ElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
float scale = 2.f;
index_t N = lengths[0];
index_t C = lengths[1];
index_t H = lengths[2];
index_t W = lengths[3];
std::vector<ck::index_t> nchw = {N, C, H, W};
std::vector<ck::index_t> nhwc = {N, H, W, C};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
Tensor<BDataType> host_b(nhwc);
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {1,
static_cast<int>(nchw[0]),
static_cast<int>(nchw[0] * nchw[1]),
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
std::array<ck::index_t, 4> b_strides = {1,
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
static_cast<int>(nhwc[0]),
static_cast<int>(nhwc[0] * nhwc[1])};
ck::ranges::copy(nchw, ab_lengths.begin());
std::cout << "A: " << a.mDesc << std::endl;
std::cout << "B: " << b.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1: a.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2}); break;
default: // a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
auto i = 0;
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n)
{
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i;
i = dis(gen);
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
ElementOp,
UnaryOp,
Scale,
NumDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
if(do_verification)
{
host_elementwise4D(host_b, a, ElementOp{}, UnaryOp{}, scale);
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
ElementOp{},
UnaryOp{},
Scale{scale});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
b_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
pass &= ck::utils::check_err(
b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_instance_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", lengths, ",") << ", ";
std::cout << "best perf = " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
}
return true;
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
namespace ck {
template <typename HostTensorA,
typename HostTensorB,
typename AElementOp,
typename BElementOp,
typename ScaleElementOp>
void reference_permute_scale(HostTensorB& b_tensor,
const HostTensorA& a_tensor,
AElementOp a_tensor_op,
BElementOp b_tensor_op,
ScaleElementOp scale_op)
{
b_tensor.ForEach([&](auto& self, auto idx) {
auto tmp_val = a_tensor(idx);
b_tensor_op(tmp_val, tmp_val);
scale_op(tmp_val, tmp_val);
a_tensor_op(self(idx), tmp_val);
});
}
namespace profiler {
template <typename ADataType, typename BDataType, index_t NumDim>
bool profile_permute_scale_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> lengths_vector,
std::vector<index_t> input_strides_vector,
std::vector<index_t> output_strides_vector)
{
bool pass = true;
bool instance_found = false;
using ElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
float scale = 2.f;
Tensor<ADataType> a(lengths_vector, input_strides_vector);
Tensor<BDataType> b(lengths_vector, output_strides_vector);
Tensor<BDataType> host_b(lengths_vector, output_strides_vector);
std::cout << "A: " << a.mDesc << std::endl;
std::cout << "B: " << b.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1: a.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2}); break;
default: a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
ElementOp,
UnaryOp,
Scale,
NumDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
if(do_verification)
{
reference_permute_scale(host_b, a, ElementOp{}, UnaryOp{}, Scale{scale});
}
auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
std::array<ck::index_t, NumDim> lengths{};
std::array<ck::index_t, NumDim> input_strides{};
std::array<ck::index_t, NumDim> output_strides{};
copy(lengths_vector, lengths);
copy(input_strides_vector, input_strides);
copy(output_strides_vector, output_strides);
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(lengths,
{input_strides},
{output_strides},
input,
output,
ElementOp{},
UnaryOp{},
Scale{scale});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
instance_found = true;
b_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
pass &= ck::utils::check_err(
b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * a.mDesc.GetElementSpaceSize() / sizeof(ADataType);
std::size_t num_btype = sizeof(ADataType) * a.mDesc.GetElementSpaceSize() +
sizeof(BDataType) * b.mDesc.GetElementSpaceSize();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_instance_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
if(time_kernel)
{
std::cout << "Best perf = " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
}
return pass && instance_found;
}
} // namespace profiler
} // namespace ck
......@@ -32,6 +32,7 @@ set(PROFILER_SOURCES
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
)
if(DL_KERNELS)
......@@ -51,6 +52,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif()
......@@ -99,6 +101,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_d
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
......@@ -124,6 +127,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
endif()
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
};
enum struct GemmDataType
{
BF16_I8_BF16, // 0
F16_F16_F16, // 1
F16_F8_F16, // 2
F16_I8_F16, // 3
};
#define OP_NAME "grouped_gemm_fixed_nk"
#define OP_DESC "Grouped GEMM Fixed NK"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: bf16@int8; 1: fp16; 2: fp16@fp8; 3: fp16@int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 1)\n"
<< "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
using F32 = float;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
}
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
1,
n_warmup,
n_iter);
}
#endif
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_fixed_nk);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_permute_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
namespace {
enum struct DataType
{
F32_F32, // 0
F16_F16 // 1
};
#define OP_NAME "permute_scale"
#define OP_DESC "Permute Scale"
static void print_helper_msg()
{
std::cout
// clang-format off
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: Input fp32, Output fp32\n"
<< " 1: Input fp16, Output fp16\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "from arg8: tensor lengths\n"
<< " input strides\n"
<< " output strides\n" << std::endl;
// clang-format on
}
} // namespace
int profile_permute_scale(int argc, char* argv[])
{
constexpr int control_argc = 7;
const int dims_argc = argc - control_argc;
// Number of lenghs, input strides and outputs strides must be equal
if(argc < control_argc && dims_argc % 3 != 0)
{
print_helper_msg();
return 1;
}
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
const int num_dims = dims_argc / 3;
std::vector<ck::index_t> lengths(num_dims);
std::vector<ck::index_t> input_strides(num_dims);
std::vector<ck::index_t> output_strides(num_dims);
for(int i = 0; i < num_dims; i++)
{
lengths[i] = std::stoi(argv[control_argc + i]);
input_strides[i] = std::stoi(argv[control_argc + num_dims + i]);
output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
}
using F32 = float;
using F16 = ck::half_t;
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
constexpr auto I4 = ck::Number<4>{};
constexpr auto I5 = ck::Number<5>{};
constexpr auto I6 = ck::Number<6>{};
auto profile = [&](auto num_dim_tmp, auto in_type, auto out_type) {
constexpr ck::index_t NDim = num_dim_tmp.value;
using InDataType = decltype(in_type);
using OutDataType = decltype(out_type);
bool pass =
ck::profiler::profile_permute_scale_impl<InDataType, OutDataType, NDim>(do_verification,
init_method,
do_log,
time_kernel,
lengths,
input_strides,
output_strides);
return pass ? 0 : 1;
};
if(num_dims == 1)
{
if(data_type == DataType::F32_F32)
{
return profile(I1, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I1, F16{}, F16{});
}
}
else if(num_dims == 2)
{
if(data_type == DataType::F32_F32)
{
return profile(I2, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I2, F16{}, F16{});
}
}
else if(num_dims == 3)
{
if(data_type == DataType::F32_F32)
{
return profile(I3, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I3, F16{}, F16{});
}
}
else if(num_dims == 4)
{
if(data_type == DataType::F32_F32)
{
return profile(I4, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I4, F16{}, F16{});
}
}
else if(num_dims == 5)
{
if(data_type == DataType::F32_F32)
{
return profile(I5, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I5, F16{}, F16{});
}
}
else if(num_dims == 6)
{
if(data_type == DataType::F32_F32)
{
return profile(I6, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I6, F16{}, F16{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_permute_scale);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_permute_scale_impl.hpp"
#include "profiler/profile_permute_scale_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
......@@ -15,15 +15,32 @@ class TestPermute : public ::testing::Test
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
void Run()
constexpr bool skip_case()
{
std::vector<std::vector<ck::index_t>> lengths = {
{4, 2, 1, 8}, {1, 1, 1, 1}, {16, 8, 32, 64}, {32, 64, 128, 128}};
#ifndef CK_ENABLE_FP16
if constexpr(ck::is_same_v<ADataType, F16> || ck::is_same_v<BDataType, F16>)
{
return true;
}
#endif
#ifndef CK_ENABLE_FP32
if constexpr(ck::is_same_v<ADataType, F32> || ck::is_same_v<BDataType, F32>)
{
return true;
}
#endif
return false;
}
for(auto length : lengths)
template <ck::index_t NDims>
void Run(std::vector<ck::index_t> lengths,
std::vector<ck::index_t> input_strides,
std::vector<ck::index_t> output_strides)
{
if(!skip_case())
{
bool success =
ck::test_permute_scale_impl<ADataType, BDataType, 4>(true, 2, false, false, length);
bool success = ck::profiler::profile_permute_scale_impl<ADataType, BDataType, NDims>(
true, 2, false, false, lengths, input_strides, output_strides);
EXPECT_TRUE(success);
}
}
......@@ -32,5 +49,52 @@ class TestPermute : public ::testing::Test
using KernelTypes = ::testing::Types<std::tuple<F16, F16>, std::tuple<F32, F32>>;
TYPED_TEST_SUITE(TestPermute, KernelTypes);
TYPED_TEST(TestPermute, Test_FP16) { this->Run(); }
TYPED_TEST(TestPermute, Test_FP32) { this->Run(); }
TYPED_TEST(TestPermute, Test1D)
{
constexpr ck::index_t NumDims = 1;
this->template Run<NumDims>({8}, {1}, {2});
this->template Run<NumDims>({8}, {2}, {1});
this->template Run<NumDims>({1}, {1}, {1});
}
TYPED_TEST(TestPermute, Test2D)
{
constexpr ck::index_t NumDims = 2;
this->template Run<NumDims>({8, 4}, {4, 1}, {1, 8});
this->template Run<NumDims>({8, 4}, {1, 8}, {4, 1});
this->template Run<NumDims>({1, 1}, {1, 1}, {1, 1});
}
TYPED_TEST(TestPermute, Test3D)
{
constexpr ck::index_t NumDims = 3;
this->template Run<NumDims>({2, 4, 4}, {16, 4, 1}, {1, 2, 8});
this->template Run<NumDims>({2, 4, 4}, {1, 2, 8}, {16, 4, 1});
this->template Run<NumDims>({1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
TYPED_TEST(TestPermute, Test4D)
{
constexpr ck::index_t NumDims = 4;
this->template Run<NumDims>({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32});
this->template Run<NumDims>({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1});
}
TYPED_TEST(TestPermute, Test5D)
{
constexpr ck::index_t NumDims = 5;
this->template Run<NumDims>({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128});
this->template Run<NumDims>({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
}
TYPED_TEST(TestPermute, Test6D)
{
constexpr ck::index_t NumDims = 6;
this->template Run<NumDims>(
{2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512});
this->template Run<NumDims>(
{2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1});
}
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
add_custom_target(test_wrapper)
add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_gtest_executable(test_gemm test_gemm.cpp)
target_link_libraries(test_gemm PRIVATE utility)
add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
target_link_libraries(test_wrapper_gemm PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_gemm)
endif()
......@@ -20,23 +20,25 @@
template <typename InputTensor,
typename OutputTensor,
typename BlockShape,
typename ThreadLayoutShape,
typename ThreadLayout,
bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
const ThreadLayout thread_layout)
{
__shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
p_shared, ck::wrapper::make_layout(tile_shape));
const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
const auto block_idxs =
ck::make_tuple(static_cast<ck::index_t>(blockIdx.x), static_cast<ck::index_t>(blockIdx.y));
// Get local tiles for global memory
const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
const auto input_local_tile =
ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
const auto output_local_tile =
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
// Get partition per thread
const auto input_local_partition =
......@@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor,
// Allocate VGPR
auto tensor_vgpr =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
layout(lds_local_partition));
ck::wrapper::make_layout(shape(lds_local_partition)));
// Perform copy
if constexpr(UseOptimizedCopy)
......@@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS()
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);
const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}));
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const ck::index_t grid_size = ck::math::integer_divide_ceil(
ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));
const auto kernel = TestCopyDevice<decltype(input_tensor_global),
decltype(output_tensor_global),
......@@ -112,7 +117,7 @@ void PerformCopyGlobalToGlobalViaLDS()
UseOptimizedCopy>;
launch_and_time_kernel(StreamConfig{},
kernel,
dim3(grid_size),
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
input_tensor_global,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
a_m_k.mData = a_data;
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Add extra M and N
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
// (dim1, dim2, dim0 thread layout)
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType,
ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1,
16,
false>(512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......
......@@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
// row-major thread layout
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{}));
// 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
const auto thread_projection =
ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
......@@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile)
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
const auto block_projection =
ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto num_blocks =
const auto grid_shape =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));
std::vector<ck::index_t> block_idxs(ck::wrapper::size(num_blocks));
std::iota(block_idxs.begin(), block_idxs.end(), 0);
std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t, ck::index_t>> block_idxs;
for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++)
{
for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++)
{
for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++)
{
block_idxs.emplace_back(i, j, k, 0);
}
}
}
for(auto block_idx : block_idxs)
{
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);
block_idx /= ck::wrapper::size<2>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<1>(block_idx) *
ck::wrapper::size<1>(block_shape) *
ck::wrapper::size<1>(strides);
block_idx /= ck::wrapper::size<1>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<0>(block_idx) *
ck::wrapper::size<0>(block_shape) *
ck::wrapper::size<0>(strides);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment