Unverified Commit e6bb1dd7 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Merge branch 'develop' into feature/check-window-lengths

parents 9d6a3704 ab250afd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convinvscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths,
const std::size_t& ds_size)
{
// G * N * C * <output spatial lengths product> * (2 * K * <filter spatial lengths product> +
// <number of scale factors>)
ck::index_t G = weights_lengths[0];
ck::index_t N = output_lengths[1];
ck::index_t K = weights_lengths[1];
ck::index_t C = weights_lengths[2];
return G * N * C *
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
(static_cast<std::size_t>(2) * K *
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) +
ds_size);
}
template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
ck::index_t NumNonSpatialDim = 3,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool run_grouped_conv_fwd_convscale(
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
std::size_t in_mem_size = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);
SimpleDeviceMem in(in_mem_size);
SimpleDeviceMem wei(wei_mem_size);
SimpleDeviceMem out(out_mem_size);
float scale_in = float(std::rand()) / float(RAND_MAX);
float scale_wei = float(std::rand()) / float(RAND_MAX);
float scale_out = float(std::rand()) / float(RAND_MAX);
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
in_strides.fill(0);
wei_strides.fill(0);
out_strides.fill(0);
in_strides.back() = 1;
wei_strides.back() = 1;
out_strides.back() = 1;
std::partial_sum(rbegin(in_lengths),
std::prev(rend(in_lengths)),
std::next(rbegin(in_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(wei_lengths),
std::prev(rend(wei_lengths)),
std::next(rbegin(wei_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(out_lengths),
std::prev(rend(out_lengths)),
std::next(rbegin(out_strides)),
std::multiplies<>{});
// transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
std::rotate(rbegin(in_lengths),
std::next(rbegin(in_lengths)),
std::next(rbegin(in_lengths), NumDimSpatial + 1));
std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
std::rotate(rbegin(in_strides),
std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
std::rotate(rbegin(out_lengths),
std::next(rbegin(out_lengths)),
std::next(rbegin(out_lengths), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
std::rotate(rbegin(out_strides),
std::next(rbegin(out_strides)),
std::next(rbegin(out_strides), NumDimSpatial + 1));
std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
std::array<ck::index_t, NumDimSpatial> input_left_pads;
std::array<ck::index_t, NumDimSpatial> input_right_pads;
conv_filter_strides.fill(1);
conv_filter_dilations.fill(1);
input_left_pads.fill(1);
input_right_pads.fill(1);
std::size_t ds_size = 3; // 3 element-wise scale multipliers
std::size_t flop = GetFlops<NumDimSpatial>(out_lengths, wei_lengths, ds_size);
std::size_t num_bytes =
in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
ConvScale,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScale{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return false;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScale{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::bf8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = InDataType;
using BComputeDataType = AComputeDataType;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::f8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::bf8_t;
using BComputeDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::bf8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::bf8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths,
const std::size_t& ds_size)
{
// G * N * C * <output spatial lengths product> * (2 * K * <filter spatial lengths product> +
// <number of scale factors>)
ck::index_t G = weights_lengths[0];
ck::index_t N = output_lengths[1];
ck::index_t K = weights_lengths[1];
ck::index_t C = weights_lengths[2];
return G * N * C *
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
(static_cast<std::size_t>(2) * K *
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) +
ds_size);
}
template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
ck::index_t NumNonSpatialDim = 3,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool run_grouped_conv_fwd_convscale_relu(
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
std::size_t in_mem_size = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);
SimpleDeviceMem in(in_mem_size);
SimpleDeviceMem wei(wei_mem_size);
SimpleDeviceMem out(out_mem_size);
float scale_in = float(std::rand()) / float(RAND_MAX);
float scale_wei = float(std::rand()) / float(RAND_MAX);
float scale_out = float(std::rand()) / float(RAND_MAX);
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
in_strides.fill(0);
wei_strides.fill(0);
out_strides.fill(0);
in_strides.back() = 1;
wei_strides.back() = 1;
out_strides.back() = 1;
std::partial_sum(rbegin(in_lengths),
std::prev(rend(in_lengths)),
std::next(rbegin(in_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(wei_lengths),
std::prev(rend(wei_lengths)),
std::next(rbegin(wei_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(out_lengths),
std::prev(rend(out_lengths)),
std::next(rbegin(out_strides)),
std::multiplies<>{});
// transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
std::rotate(rbegin(in_lengths),
std::next(rbegin(in_lengths)),
std::next(rbegin(in_lengths), NumDimSpatial + 1));
std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
std::rotate(rbegin(in_strides),
std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
std::rotate(rbegin(out_lengths),
std::next(rbegin(out_lengths)),
std::next(rbegin(out_lengths), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
std::rotate(rbegin(out_strides),
std::next(rbegin(out_strides)),
std::next(rbegin(out_strides), NumDimSpatial + 1));
std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
std::array<ck::index_t, NumDimSpatial> input_left_pads;
std::array<ck::index_t, NumDimSpatial> input_right_pads;
conv_filter_strides.fill(1);
conv_filter_dilations.fill(1);
input_left_pads.fill(1);
input_right_pads.fill(1);
std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 elementwise Relu
std::size_t flop = GetFlops<NumDimSpatial>(out_lengths, wei_lengths, ds_size);
std::size_t num_bytes =
in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
ConvScaleRelu,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScaleRelu{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return false;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScaleRelu{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale_relu<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -7,19 +7,23 @@
#include <initializer_list>
#include <vector>
#include "ck/utility/common_header.hpp"
// __gfx9__ defined in the above header via ck.hpp
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
struct SimpleDeviceMem
{
......@@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M,
int main(int argc, char* argv[])
{
bool is_supported = ck::is_xdl_supported();
if(!is_supported)
{
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
......@@ -213,3 +225,4 @@ int main(int argc, char* argv[])
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
#endif
......@@ -181,4 +181,3 @@ int main(int argc, char* argv[])
{1, 1, 1} /*filter_dilations*/);
return 0;
}
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
......@@ -7,18 +7,21 @@
#include <initializer_list>
#include <vector>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/utility/common_header.hpp"
// __gfx9__ defined in the above header via ck.hpp
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
struct SimpleDeviceMem
{
......@@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M,
int main(int argc, char* argv[])
{
bool is_supported = ck::is_xdl_supported();
if(!is_supported)
{
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
......@@ -305,3 +316,4 @@ int main(int argc, char* argv[])
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
#endif
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"
......
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"
......
......@@ -6,46 +6,36 @@ if (DTYPES)
add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8")
add_definitions(-DCK_ENABLE_INT8)
if(NOT DEFINED ${CK_ENABLE_INT8})
set(CK_ENABLE_INT8 "ON")
endif()
set(CK_ENABLE_INT8 "ON")
endif()
if (DTYPES MATCHES "fp8")
add_definitions(-DCK_ENABLE_FP8)
if(NOT DEFINED ${CK_ENABLE_FP8})
set(CK_ENABLE_FP8 "ON")
endif()
set(CK_ENABLE_FP8 "ON")
endif()
if (DTYPES MATCHES "bf8")
add_definitions(-DCK_ENABLE_BF8)
set(CK_ENABLE_BF8 "ON")
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-DCK_ENABLE_FP16)
if(NOT DEFINED ${CK_ENABLE_FP16})
set(CK_ENABLE_FP16 "ON")
endif()
set(CK_ENABLE_FP16 "ON")
endif()
if (DTYPES MATCHES "fp32")
add_definitions(-DCK_ENABLE_FP32)
if(NOT DEFINED ${CK_ENABLE_FP32})
set(CK_ENABLE_FP32 "ON")
endif()
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
if(NOT DEFINED ${CK_ENABLE_FP64})
set(CK_ENABLE_FP64 "ON")
endif()
set(CK_ENABLE_FP64 "ON")
endif()
if (DTYPES MATCHES "bf16")
add_definitions(-DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_BF16})
set(CK_ENABLE_BF16 "ON")
endif()
set(CK_ENABLE_BF16 "ON")
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
set(CK_ENABLE_ALL_DTYPES "ON")
endif()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON")
endif()
if (GPU_TARGETS)
......@@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}")
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build"))
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")
AND (NOT "${subdir}" MATCHES ".vscode"))
add_subdirectory(${subdir})
ENDIF()
ENDFOREACH()
......@@ -2,7 +2,7 @@
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
# Copyright (c) 2017-2024 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......@@ -96,6 +96,7 @@ else()
-Wno-covered-switch-default
-Wno-unsafe-buffer-usage
-Wno-unused-lambda-capture
-Wno-nvcc-compat
)
else()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
......
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host)
project(composable_kernel_host LANGUAGES CXX HIP)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
......@@ -12,24 +12,38 @@ find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
add_compile_options(-std=c++17)
find_package(hip)
## HIP
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
set(CMAKE_HIP_EXTENSIONS ON)
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
# add include directories
include_directories(BEFORE
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/library/include
${HIP_INCLUDE_DIRS}
)
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck/*.hpp)
${CK_ROOT}/include/ck/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
add_definitions(-std=c++17)
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
......
......@@ -5,24 +5,27 @@
#include <unordered_map>
#include <vector>
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/stringutils.hpp"
using ck::host::Transform;
struct Emitters
{
// retrieve the hard-coded instances provided, template them, and then store them in a map
std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;
template <class T>
void Register(const std::string& name)
void Register(const std::string& name, const std::string& prologue, const std::string& epilogue)
{
m[name] = [] {
auto configs = T::CreateOperations();
m[name] = [&] {
auto configs = T::CreateOperations(prologue, epilogue);
return Transform(configs, [](const auto& ops) { return ToTuple(ops); });
};
}
// takes in an operation instance and uses it to substitute the correct values into the template
template <class T>
static std::string ToTuple(const T& ops)
{
......@@ -31,6 +34,7 @@ struct Emitters
return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">";
}
// Join together all the strings in the map
std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); }
std::vector<std::string> List() const
......@@ -43,9 +47,38 @@ int main(int argc, const char* argv[])
{
std::string prog = argv[0];
std::vector<std::string> args(argv + 1, argv + argc);
// Specify problem type and problem size
ck::host::device_gemm_multiple_d::Problem prob;
prob.M = 1024;
prob.N = 1024;
prob.K = 1024;
// user provided fusion
std::string prologue = "";
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};)";
// Load in operations into the Register
Emitters e;
e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
"DeviceGemmMultipleD_Xdl_CShuffle");
"DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue);
if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) {
return arg == "-h" or arg == "--help";
......@@ -64,6 +97,7 @@ int main(int argc, const char* argv[])
return 0;
}
// print out all the instances for the operation that was chosen at the command line
for(auto name : args)
std::cout << e.Emit(name) << std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
......@@ -14,10 +14,15 @@ namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
static std::vector<std::vector<Operation_Xdl_CShuffle>> CreateOperations();
static std::vector<Operation_Xdl_CShuffle> CreateOperations(const Problem& prob);
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
DataType acc = DataType::Float;
......@@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string cde_elem_op = Bilinear;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDesc tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
// returns a templated instance
Solution ToSolution() const;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -12,11 +12,14 @@ namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// defines the problem specification for a GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
// dimensions for GEMM operation
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
// layouts for tensors
bool TransA = false;
bool TransB = false;
bool TransE = false;
......@@ -29,9 +32,13 @@ struct Problem
std::string BElementOp = PassThrough;
std::string CDEElementOp = PassThrough;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_gemm_multiple_d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment