"tests/vscode:/vscode.git/clone" did not exist on "2c60f7d14e5297a61301c8bb2698717c244d3e43"
Commit 10cabc2d authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_splitk

parents baf68688 091570f5
...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perlayer_quantization_example.inc" #include "run_conv2d_fwd_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op);
}
...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perchannel_quantization_example() int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example() ...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perlayer_quantization_example() int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example() ...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perchannel_quantization_example() int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example() ...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perlayer_quantization_example() int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = false; bool time_kernel = false;
...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example() ...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
// FMA instruction // FMA instruction
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -7,10 +7,30 @@ namespace ck { ...@@ -7,10 +7,30 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenient
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul_Clamp struct Activation_Mul_Clamp
{ {
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp(float requantScale, Activation activationOp) Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp ...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp
Activation activationOp_; Activation activationOp_;
}; };
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float y_fp32 = ck::type_convert<float>(x);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as // Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc // relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul2_Clamp struct Activation_Mul2_Clamp
{ {
...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp ...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp
}; };
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Activation_Mul_Clamp struct Add_Activation_Mul_Clamp
{ {
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activaiton(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp ...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp
}; };
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc // For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp struct Add_Mul_Activation_Mul_Clamp
{ {
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) // Convolution + Activation (non piecewise linear function)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) // Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{ {
} }
...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp ...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{ {
float y_fp32 = ck::type_convert<float>(x + bias); float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32; y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy *Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float requantScale1_; __host__ __device__ constexpr void
float requantScale2_; operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
Activation activationOp_; Activation activationOp_;
}; };
......
...@@ -320,6 +320,19 @@ struct Sigmoid ...@@ -320,6 +320,19 @@ struct Sigmoid
int32_t divider_ = 1; int32_t divider_ = 1;
}; };
struct TanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
};
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if defined(__gfx90a__)
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
using ABDataTypeAdjusted = ABDataType;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1, BK1), math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataTypeAdjusted,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned, static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -166,15 +166,12 @@ __global__ void ...@@ -166,15 +166,12 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size = __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared,
a_b_k0_m_k1_grid_desc, a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc, b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -183,16 +180,16 @@ __global__ void ...@@ -183,16 +180,16 @@ __global__ void
c_element_op, c_element_op,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc; ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc; ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = c_block_cluster_adaptor; ignore = c_block_cluster_adaptor;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{}; static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc), decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc), decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk); math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared), static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -58,16 +58,16 @@ __global__ void ...@@ -58,16 +58,16 @@ __global__ void
c_element_op, c_element_op,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm = using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1),
...@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(a_grid_desc_k0_m_k1), decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(b_grid_desc_k0_n_k1), decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check // sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1),
...@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize()); b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
...@@ -948,7 +948,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -948,7 +948,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// clang-format off // clang-format off
str << "GemmXdlSplitKCShuffle_" str << "GemmXdlSplitKCShuffle_"
<< getGemmSpecializationString(GemmSpec) << "_" << getGemmSpecializationString(GemmSpec) << "_"
<< LStr<ALayout>::Get() << LStr<BLayout>::Get() << LStr<CLayout>::Get() << "_" << std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< "_"
<< "B" << BlockSize << "_" << "B" << BlockSize << "_"
<< "Vec" << ABlockTransferSrcScalarPerVector << "x" << "Vec" << ABlockTransferSrcScalarPerVector << "x"
<< BBlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << "x"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -1008,6 +1008,60 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -1008,6 +1008,60 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
...@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; ...@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
static inline __host__ double tanh(double x) { return std::tanh(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); }; static inline __device__ float abs(float x) { return ::abs(x); };
...@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); ...@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
static inline __device__ double tanh(double x) { return ::tanh(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
...@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>; ...@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu; using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
...@@ -102,6 +103,10 @@ template <typename Activation> ...@@ -102,6 +103,10 @@ template <typename Activation>
using Add_Activation_Mul_Clamp = using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Mul_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
template <typename Activation> template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>; using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
...@@ -109,6 +114,10 @@ template <typename Activation> ...@@ -109,6 +114,10 @@ template <typename Activation>
using Add_Activation_Mul2_Clamp = using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Mul2_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void> template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
......
...@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( ...@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( ...@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
}; };
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment