Commit bbe74503 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 8b76b832 f53ede26
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -27,6 +27,8 @@ enum struct MfmaInstr ...@@ -27,6 +27,8 @@ enum struct MfmaInstr
mfma_f32_16x16x8bf16, mfma_f32_16x16x8bf16,
mfma_i32_32x32x8i8, mfma_i32_32x32x8i8,
mfma_i32_16x16x16i8, mfma_i32_16x16x16i8,
mfma_i32_32x32x16i8,
mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64 mfma_f64_16x16x4f64
}; };
...@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8> ...@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
} }
}; };
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <> template <>
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{ {
...@@ -524,17 +570,29 @@ struct MfmaSelector ...@@ -524,17 +570,29 @@ struct MfmaSelector
#endif #endif
} }
#if defined(CK_USE_AMD_MFMA_GFX940)
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() static constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() static constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x16i8; return MfmaInstr::mfma_i32_16x16x16i8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
......
...@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> ...@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
} }
}; };
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;
template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64; struct intrin_mfma_f64_16x16x4f64;
...@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
{ {
#ifdef __gfx90a__ #if defined(__gfx90a__) || defined(__gfx940__)
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else #else
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -898,6 +898,8 @@ struct vector_type<T, 256> ...@@ -898,6 +898,8 @@ struct vector_type<T, 256>
} }
}; };
using int64_t = long;
// fp64 // fp64
using double2_t = typename vector_type<double, 2>::type; using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type; using double4_t = typename vector_type<double, 4>::type;
...@@ -974,42 +976,94 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -974,42 +976,94 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
uint32_t int32; uint32_t int32;
} u = {x}; } u = {x};
if(~u.int32 & 0x7f800000) // When the exponent bits are not all 1s, then the value is zero, normal,
{ // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// When the exponent bits are not all 1s, then the value is zero, normal, // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus // This causes the bfloat16's mantissa to be incremented by 1 if the 16
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). // least significant bits of the float mantissa are greater than 0x8000,
// This causes the bfloat16's mantissa to be incremented by 1 if the 16 // or if they are equal to 0x8000 and the least significant bit of the
// least significant bits of the float mantissa are greater than 0x8000, // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// or if they are equal to 0x8000 and the least significant bit of the // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when // has the value 0x7f, then incrementing it causes it to become 0x00 and
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already // the exponent is incremented by one, which is the next higher FP value
// has the value 0x7f, then incrementing it causes it to become 0x00 and // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// the exponent is incremented by one, which is the next higher FP value // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// to a normal value with an exponent of 0x01 and a mantissa of 0x00. // incrementing it causes it to become an exponent of 0xFF and a mantissa
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, // of 0x00, which is Inf, the next higher value to the unrounded value.
// incrementing it causes it to become an exponent of 0xFF and a mantissa bool flag0 = ~u.int32 & 0x7f800000;
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even // When all of the exponent bits are 1, the value is Inf or NaN.
} // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
else if(u.int32 & 0xffff) // mantissa bit. Quiet NaN is indicated by the most significant mantissa
{ // bit being 1. Signaling NaN is indicated by the most significant
// When all of the exponent bits are 1, the value is Inf or NaN. // mantissa bit being 0 but some other bit(s) being 1. If any of the
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero // lower 16 bits of the mantissa are 1, we set the least significant bit
// mantissa bit. Quiet NaN is indicated by the most significant mantissa // of the bfloat16 mantissa, in order to preserve signaling NaN in case
// bit being 1. Signaling NaN is indicated by the most significant // the bfloat16's mantissa bits are all 0.
// mantissa bit being 0 but some other bit(s) being 1. If any of the bool flag1 = !flag0 && (u.int32 & 0xffff);
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
// the bloat16's mantissa bits are all 0. u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
...@@ -135,6 +135,28 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h ...@@ -135,6 +135,28 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
c); c);
} }
template <>
__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
{
c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
}
template <>
__device__ void
inner_product<int8x2_t, int8x2_t, int32_t>(const int8x2_t& a, const int8x2_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I0],
vector_type<int8_t, 2>{b}.AsType<int8_t>()[I0],
c);
inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I1],
vector_type<int8_t, 2>{b}.AsType<int8_t>()[I1],
c);
}
template <> template <>
__device__ void __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c) inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
......
...@@ -168,6 +168,10 @@ __device__ double exp<double>(double x) ...@@ -168,6 +168,10 @@ __device__ double exp<double>(double x)
return exp(x); return exp(x);
} }
static inline __host__ float exp(float x) { return std::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{ {
......
...@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; ...@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
static inline __host__ double tanh(double x) { return std::tanh(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); }; static inline __device__ float abs(float x) { return ::abs(x); };
...@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); ...@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
static inline __device__ double tanh(double x) { return ::tanh(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
...@@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>; ...@@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>;
using F16_Tuple = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>; using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>; using I32_F32_Tuple = ck::Tuple<I32, F32>;
...@@ -85,6 +86,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>; ...@@ -85,6 +86,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu; using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
...@@ -94,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu; ...@@ -94,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename Activation> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>; using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
...@@ -102,6 +105,10 @@ template <typename Activation> ...@@ -102,6 +105,10 @@ template <typename Activation>
using Add_Activation_Mul_Clamp = using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Mul_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
template <typename Activation> template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>; using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
...@@ -109,6 +116,10 @@ template <typename Activation> ...@@ -109,6 +116,10 @@ template <typename Activation>
using Add_Activation_Mul2_Clamp = using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Mul2_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void> template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
......
...@@ -19,6 +19,7 @@ namespace tensor_operation { ...@@ -19,6 +19,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// float
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2, 2,
...@@ -67,6 +68,55 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn ...@@ -67,6 +68,55 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear>>>& instances);
// double
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
// Contraction + Bilinear // Contraction + Bilinear
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
...@@ -118,6 +168,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -118,6 +168,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
} }
} }
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
op_ptrs);
}
}
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -19,6 +19,7 @@ namespace tensor_operation { ...@@ -19,6 +19,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// float
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2, 2,
...@@ -67,6 +68,55 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc ...@@ -67,6 +68,55 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough, PassThrough,
Scale>>>& instances); Scale>>>& instances);
// double
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
// Contraction + Scale // Contraction + Scale
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
...@@ -117,6 +167,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -117,6 +167,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
} }
} }
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<EDataType, double>)
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
op_ptrs);
}
}
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, GNHWC,
...@@ -159,20 +145,21 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -159,20 +145,21 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, BF16,
int8_t, BF16,
Empty_Tuple, Empty_Tuple,
int8_t, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC, NHWGC,
...@@ -187,6 +174,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -187,6 +174,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -385,12 +386,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -385,12 +386,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> && else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>) is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
...@@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
// no instance add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
...@@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
// no instance add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
// no instance
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
......
...@@ -68,6 +68,58 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( ...@@ -68,6 +68,58 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -109,11 +161,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -109,11 +161,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
// FP32
void add_device_normalization_rank_5_3_swish_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
// [x, gamma, beta, y] = [f16, f32, f32, f16]
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Col, Col, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Row, Row, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Row, Col, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Col, Col, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Row, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
// Layout(A, B, C) = [Row, Col, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -17,14 +17,14 @@ namespace tensor_operation { ...@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_bias_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -34,12 +34,12 @@ void add_device_conv2d_bias_perchannel_quantization_int8_instances( ...@@ -34,12 +34,12 @@ void add_device_conv2d_bias_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<PassThrough>>>>& Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -49,6 +49,70 @@ void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( ...@@ -49,6 +49,70 @@ void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -90,17 +154,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -90,17 +154,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{ {
if constexpr(is_same_v<Activation, PassThrough>) if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>) else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
} }
} }
......
...@@ -17,14 +17,14 @@ namespace tensor_operation { ...@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_bias_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -34,12 +34,12 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances( ...@@ -34,12 +34,12 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<PassThrough>>>>& Add_Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -49,6 +49,68 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( ...@@ -49,6 +49,68 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<Relu>>>>& Add_Activation_Mul_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -90,17 +152,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -90,17 +152,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{ {
if constexpr(is_same_v<Activation, PassThrough>) if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>) else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
}
} }
} }
......
...@@ -17,13 +17,13 @@ namespace tensor_operation { ...@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_perchannel_quantization_int8_instances( void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -33,12 +33,42 @@ void add_device_conv2d_perchannel_quantization_int8_instances( ...@@ -33,12 +33,42 @@ void add_device_conv2d_perchannel_quantization_int8_instances(
Activation_Mul2_Clamp<PassThrough>>>>& Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
GK_Tuple,
NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -89,17 +119,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -89,17 +119,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
if constexpr(is_same_v<Activation, PassThrough>) if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>) else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs);
}
} }
} }
......
...@@ -17,13 +17,13 @@ namespace tensor_operation { ...@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_perlayer_quantization_int8_instances( void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -33,12 +33,42 @@ void add_device_conv2d_perlayer_quantization_int8_instances( ...@@ -33,12 +33,42 @@ void add_device_conv2d_perlayer_quantization_int8_instances(
Activation_Mul_Clamp<PassThrough>>>>& Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -87,16 +117,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -87,16 +117,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>) is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
if constexpr(is_same_v<Activation, PassThrough>) if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_perlayer_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>) else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_relu_perlayer_quantization_int8_instances(op_ptrs); {
add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs);
}
} }
} }
......
...@@ -47,7 +47,9 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16 ...@@ -47,7 +47,9 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#if CK_WORKAROUND_SWDEV_388832
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#endif
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
......
...@@ -47,7 +47,9 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ ...@@ -47,7 +47,9 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#if CK_WORKAROUND_SWDEV_388832
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#endif
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment