Unverified Commit 408534d4 authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents a8efb3f0 da214a5a
......@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
return ck::type_convert<T>(__ocml_exp_f32(ck::type_convert<float>(x)));
};
template <>
......@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
return __ocml_exp_f32(x);
};
template <>
......
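For reference, a minimal host-only sketch of the convert-through-float pattern these hunks keep while swapping __expf for __ocml_exp_f32; std::exp stands in for the device-only intrinsic and the types are illustrative:

#include <cmath>
#include <iostream>

// Widen to float, evaluate, and narrow back -- the same shape as ck::exp<T> above.
template <typename T>
T exp_via_float(T x)
{
    return static_cast<T>(std::exp(static_cast<float>(x)));
}

int main()
{
    std::cout << exp_via_float(1.0) << ' ' << exp_via_float(0.5f) << '\n'; // ~2.71828 ~1.64872
}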
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,11 +52,19 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ + b_);
}
};
struct SquaredAdd
......@@ -104,11 +112,19 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ * b_);
}
};
struct Max
......
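The bhalf_t overloads added to Add and Mul widen to float because the 16-bit storage type has no native arithmetic. Below is a self-contained sketch of that widen/combine/narrow step, with bf16 emulated by the top 16 bits of a float and truncation used instead of the library's rounding, purely for brevity:

#include <cstdint>
#include <cstring>
#include <iostream>

using bf16_bits = std::uint16_t;

static float bf16_to_float(bf16_bits x)
{
    std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16_bits float_to_bf16(float f) // truncation; the library converts with proper rounding
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bf16_bits>(bits >> 16);
}

int main()
{
    bf16_bits a = float_to_bf16(1.5f);
    bf16_bits b = float_to_bf16(2.25f);
    // mirrors Add::operator()(bhalf_t&, bhalf_t): widen, add, narrow
    a = float_to_bf16(bf16_to_float(a) + bf16_to_float(b));
    std::cout << bf16_to_float(a) << '\n'; // prints 3.75
}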
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
namespace ck {
// Define the common macro for gfx94x models
......@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
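A hedged usage sketch for the std::array overload of array_convert introduced above; static_cast stands in for ck::type_convert so the snippet stays host-only, and the sizes and values are illustrative:

#include <array>
#include <cstddef>
#include <iostream>

template <typename Y, typename X, std::size_t NumElems>
void array_convert_sketch(std::array<Y, NumElems>& y, const std::array<X, NumElems>& x)
{
    for(std::size_t i = 0; i < NumElems; i++)
        y[i] = static_cast<Y>(x[i]); // the real code calls ck::type_convert<Y>(x[i])
}

int main()
{
    std::array<float, 4> src{1.25f, 2.5f, 3.75f, 5.0f};
    std::array<int, 4> dst{};
    array_convert_sketch(dst, src);
    for(int v : dst)
        std::cout << v << ' '; // 1 2 3 5
    std::cout << '\n';
}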
......@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
};
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
bfloat16_t exp(bfloat16_t x)
{
return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
......
......@@ -835,7 +835,7 @@ CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
......@@ -860,7 +860,7 @@ CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
......
......@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
};
CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
......
......@@ -519,7 +519,7 @@ CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
......
......@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile(
    ""); // the issue reappears starting from ROCm 6.2 with the same symptom, so reuse this flag
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
......
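For context, a stand-alone illustration of the construct used in this workaround (a GCC/Clang extension, no HIP required); the effect on scratch usage is empirical and specific to the affected ROCm compilers, not something the snippet below can demonstrate by itself:

int scaled(int x)
{
    int y = x * 3;
    asm volatile(""); // emits no instructions, but a reachable volatile asm is never deleted,
                      // which constrains optimization around this point
    return y + 1;
}

int main() { return scaled(1) == 4 ? 0 : 1; }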
......@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
// TODO: starting from ROCm 6.2, the compiler has problems if q is cleared manually here.
// However, q is already cleared in the constructor of the static distributed tensor.
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
......@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<long_index_t> filter_spatial_lengths_;
std::vector<long_index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
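// Illustrative numbers (not from this change): with Wi = 8, X = 3,
// conv_dilation_w = 2, conv_stride_w = 2, and left/right pads of 1:
// XEff = (3 - 1) * 2 + 1 = 5 and Wo = (8 + 1 + 1 - 5) / 2 + 1 = 3.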
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
......@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
const long_index_t G = arg.output_.GetLengths()[0];
const long_index_t N = arg.output_.GetLengths()[1];
const long_index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) {
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
index_t column = 0;
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
......@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
const long_index_t Ho = arg.output_spatial_lengths_[0];
const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
for(long_index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
......@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
const long_index_t Do = arg.output_spatial_lengths_[0];
const long_index_t Ho = arg.output_spatial_lengths_[1];
const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
for(long_index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
for(long_index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho *
......@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
++x)
{
auto wi =
static_cast<ck::long_index_t>(
......@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
......@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2];
const ck::long_index_t G = arg.output_.GetLengths()[0];
const ck::long_index_t N = arg.output_.GetLengths()[1];
const ck::long_index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
const long_index_t NDoHoWo =
N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
const long_index_t CZYX =
C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
......@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
{
return Argument{input,
output,
......
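A hedged sketch of why these reference operators move from index_t to long_index_t: flattened row indices such as n * Do * Ho * Wo can exceed the 32-bit range for large shapes. The sizes below are purely illustrative:

#include <cstdint>
#include <iostream>
#include <limits>

int main()
{
    const std::int64_t n = 511, Do = 64, Ho = 512, Wo = 512;
    const std::int64_t row = n * Do * Ho * Wo; // 8,573,157,376
    std::cout << row << " fits in 32 bits? " << std::boolalpha
              << (row <= std::numeric_limits<std::int32_t>::max()) << '\n'; // false
}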
......@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
......@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
......@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<ck::long_index_t> conv_strides_;
std::vector<ck::long_index_t> conv_dilations_;
std::vector<ck::long_index_t> in_left_pads_;
std::vector<ck::long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
......@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
......@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<long_index_t> filter_spatial_lengths_;
std::vector<long_index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
......@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
......@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2];
const long_index_t G = arg.input_.GetLengths()[0];
const long_index_t N = arg.input_.GetLengths()[1];
const long_index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo;
index_t column = 0;
const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
......@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
const long_index_t Ho = arg.output_spatial_lengths_[0];
const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
......@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
const long_index_t Do = arg.output_spatial_lengths_[0];
const long_index_t Ho = arg.output_spatial_lengths_[1];
const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
......@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.input_.GetLengths()[0];
const ck::index_t N = arg.input_.GetLengths()[1];
const ck::index_t C = arg.input_.GetLengths()[2];
const ck::long_index_t G = arg.input_.GetLengths()[0];
const ck::long_index_t N = arg.input_.GetLengths()[1];
const ck::long_index_t C = arg.input_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
const long_index_t NDoHoWo =
N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
const long_index_t CZYX =
C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
......@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
{
return Argument{input,
output,
......
......@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename A0DataType,
typename A1DataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
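To show the shape of the instance-factory headers above and below without pulling in the CK device-op templates, here is a self-contained sketch of the same pattern: free add_*_instances functions append concrete operations into a vector of base-class pointers, and GetInstances() aggregates whichever sets apply. All names in the sketch are illustrative, not CK API:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct DeviceOpBase
{
    virtual ~DeviceOpBase() = default;
    virtual std::string GetTypeString() const = 0;
};

struct DefaultInstance : DeviceOpBase
{
    std::string GetTypeString() const override { return "comp_default"; }
};

struct KPaddingInstance : DeviceOpBase
{
    std::string GetTypeString() const override { return "comp_kpadding"; }
};

void add_default_instances(std::vector<std::unique_ptr<DeviceOpBase>>& v)
{
    v.push_back(std::make_unique<DefaultInstance>());
}

void add_kpadding_instances(std::vector<std::unique_ptr<DeviceOpBase>>& v)
{
    v.push_back(std::make_unique<KPaddingInstance>());
}

std::vector<std::unique_ptr<DeviceOpBase>> GetInstances()
{
    std::vector<std::unique_ptr<DeviceOpBase>> op_ptrs;
    add_default_instances(op_ptrs);
    add_kpadding_instances(op_ptrs);
    return op_ptrs;
}

int main()
{
    for(const auto& op : GetInstances())
        std::cout << op->GetTypeString() << '\n';
}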
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -315,7 +315,7 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
......@@ -416,6 +416,57 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_ins
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
template <typename ADataType,
typename BDataType,
......@@ -596,7 +647,7 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
......@@ -653,6 +704,33 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
using F8 = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename OutElementOp>
using device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck