Commit 930b2872 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

best performing kernel for GEMV codex problem with M=1 with inverted B matrix

parents a1e17d18 a4f72a31
......@@ -419,5 +419,200 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
}
};
#endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
} // namespace ck
#endif
......@@ -140,10 +140,36 @@ struct DynamicBuffer
}
else if constexpr(Op == InMemoryDataOperationEnum::Add)
{
auto tmp = this->template Get<X>(i, is_valid_element);
this->template Set<X>(i, is_valid_element, x + tmp);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
auto tmp = this->template Get<X>(i, is_valid_element);
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
// handle bfloat addition
if constexpr(is_same_v<scalar_t, bhalf_t>)
{
if constexpr(is_scalar_type<X>::value)
{
// Scalar type
auto result =
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
this->template Set<X>(i, is_valid_element, result);
}
else
{
// Vector type
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
const vector_type<scalar_t, vector_size> a_vector{tmp};
const vector_type<scalar_t, vector_size> b_vector{x};
static_for<0, vector_size, 1>{}([&](auto idx) {
auto result = type_convert<scalar_t>(
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
this->template Set<scalar_t>(i + idx, is_valid_element, result);
});
}
}
else
{
this->template Set<X>(i, is_valid_element, x + tmp);
}
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace detail
struct nonesuch
{
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
} // namespace ck
......@@ -897,3 +897,14 @@ template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
using S = ck::Sequence<Is...>;
os << "{";
ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}(
[&](auto i) { os << S::At(i).value << ", "; });
os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os;
}
......@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsTuple() { return true; }
};
template <>
......
......@@ -85,19 +85,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -105,33 +92,20 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -139,20 +113,14 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#endif
}
#endif
......@@ -161,19 +129,6 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -181,33 +136,20 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -215,20 +157,14 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#endif
}
#endif
......@@ -298,47 +234,30 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
#endif
......@@ -347,38 +266,21 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -388,7 +290,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/**
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template <ck::index_t NDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceColumnToImage : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
filter_spatial_lengths_{filter_spatial_lengths}
{
initOutputSpatialLengths();
}
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
in_right_pads_[i] - x_eff) /
conv_strides_[i] +
1);
}
}
};
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceColumnToImage::Argument;
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in = ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
arg.output_(0, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho *
arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(
wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.output_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
return true;
}
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
{
return false;
}
if(G != 1)
{
return false;
}
return true;
}
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
return Argument{input,
output,
filter_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceColumnToImage"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename ComputeTypeA = OutDataType,
typename ComputeTypeB = InDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator
{
......@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_out;
float v_in;
ComputeTypeA v_out;
ComputeTypeB v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
......@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in;
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
}
}
}
......@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_out;
float v_in;
ComputeTypeA v_out;
ComputeTypeB v_in;
arg.out_element_op_(
v_out,
......@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in;
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
}
}
}
......@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{
float v_out;
float v_in;
ComputeTypeA v_out;
ComputeTypeB v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
......@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in;
v_acc +=
type_convert<float>(v_out) * type_convert<float>(v_in);
}
}
}
......
......@@ -14,27 +14,27 @@ namespace ck {
namespace tensor_operation {
namespace host {
//
// @brief Reference implementation for forward convolution.
//
// @paragraph
// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type.
// @tparam InElementwiseOperation Functor for input tensor elementwise
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NDimSpatial Number of spatial dimensions.
//
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
///
/// @brief Reference implementation for forward convolution.
///
/// @paragraph
/// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
/// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
/// as long as dimensions in tensor descriptor is in GNCHW order
///
/// @tparam InDataType Input tensor data type.
/// @tparam WeiDataType Weights tensor data type.
/// @tparam OutDataType Output tensor data type.
/// @tparam InElementwiseOperation Functor for input tensor elementwise
/// operation.
/// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
/// operation.
/// @tparam NDimSpatial Number of spatial dimensions.
///
/// input descriptor in [G, N, C, Do, Ho, Wo] order
/// weight descriptor in [G, K, C, Z, Y, X] order
/// output descriptor in [G, N, K, Di, Hi, Wi] order
/// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
......
......@@ -21,7 +21,8 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputType = ADataType>
typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator
{
// Argument
......@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
ComputType v_a;
ComputType v_b;
ComputeTypeA v_a;
ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
......
......@@ -18,16 +18,18 @@ namespace host {
/**
* \brief Reference implementation for image to column.
*
* Tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam InputLayout Input Layout.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template <ck::index_t NDimSpatial,
typename InputLayout,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
......@@ -240,8 +242,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>))
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
......
......@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
#endif
template <typename XDataType,
typename DxDataType,
typename DyDataType,
......@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
......@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
#endif
template <typename XDataType,
typename YDataType,
typename AccDataType,
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
......@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,38 +16,38 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
#endif
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
......@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
// Image to Column
// nhwc, 1d
void add_device_image_to_column_nwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 2d
void add_device_image_to_column_nhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 3d
void add_device_image_to_column_ndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// Column to Image
// nhwc, 1d
void add_device_column_to_image_nwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 2d
void add_device_column_to_image_nhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 3d
void add_device_column_to_image_ndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
template <ck::index_t NumDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename ConvTensorRearrangeOp>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceConvTensorRearrange<NumDimSpatial,
ImageLayout,
InDataType,
OutDataType,
ConvTensorRearrangeOp>>
{
using DeviceOp = DeviceConvTensorRearrange<NumDimSpatial,
ImageLayout,
InDataType,
OutDataType,
ConvTensorRearrangeOp>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ConvTensorRearrangeOp, ImageToColumn>)
{
if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, GNWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_ndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_ndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_ndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_ndhwc_3d_i8_instances(op_ptrs);
}
}
}
else if constexpr(is_same_v<ConvTensorRearrangeOp, ColumnToImage>)
{
if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, GNWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_ndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_ndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_ndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_ndhwc_3d_i8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_bf16_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 8>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_f16_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 8>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_f32_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 4>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_i8_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 256, 256, S<16, 16>, 16>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -13,6 +13,7 @@ namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
......@@ -28,17 +29,12 @@ using device_image_to_column_bf16_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 8>
......@@ -52,17 +48,13 @@ using device_image_to_column_f16_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 8>
......@@ -76,15 +68,11 @@ using device_image_to_column_f32_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 4>
// clang-format on
......@@ -97,17 +85,13 @@ using device_image_to_column_i8_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 8>,
......
......@@ -312,6 +312,23 @@ void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -505,6 +522,32 @@ struct DeviceOperationInstanceFactory<
#endif
}
}
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<ADataType, ck::f8_t> && is_same_v<BDataType, ck::f8_t> &&
is_same_v<CDataType, ck::f8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
......@@ -68,7 +68,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
......@@ -120,7 +121,7 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
......@@ -158,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
......@@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
......@@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......@@ -220,4 +223,3 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
......@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
return op_ptrs;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment