Commit 76f2b6cd authored by danyao12's avatar danyao12
Browse files

merge develop to attn-train-develop-qloop

parents 9b4c780a 1ee99dca
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -56,7 +56,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
......@@ -938,7 +939,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
......@@ -49,6 +50,14 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
......@@ -67,6 +76,30 @@ struct Add
};
};
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
template <>
__host__ __device__ void
operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
float scale_;
};
struct Subtract
{
template <typename T>
......@@ -118,6 +151,13 @@ struct Bilinear
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <>
__host__ __device__ constexpr void
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
{
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
......@@ -241,43 +281,42 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float x = c + d;
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
FastGelu{}.template operator()<float, float>(e, x);
}
e = type_convert<E>(y);
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c + d;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
static_assert(is_valid_param_type_v<D>);
const float x0_f = c + d;
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = GetFastGeLU(c + type_convert<float>(d));
e = type_convert<half_t>(x1_f);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -16,7 +16,7 @@ namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// siliently do implicit type conversion
//
// Method 1:
// Example:
//
// struct ExampleElementwiseOp
// {
......@@ -30,19 +30,6 @@ namespace element_wise {
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct AddReluAdd
{
......@@ -173,40 +160,109 @@ struct AddAdd
};
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
const float x = c + d0 + d1;
FastGelu{}.template operator()<float, float>(e, x);
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const half_t x = c + d0 + d1;
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& e, const float& c, const half_t& d0, const half_t& d1) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
const float x0_f = c + d0 + d1;
const float y =
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
float x1_f = 0;
e = type_convert<E>(y);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
{
const float x0_f =
type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<int8_t>(x1_f);
}
};
......@@ -278,6 +334,40 @@ struct Normalize
double epsilon_;
};
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T3& gamma,
const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
template <typename Y, typename X>
struct UnaryTypeConvert;
......
#pragma once
#include "ck/utility/data_type.hpp"
// #include "ck/utility/get_id.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenient
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Activation_Mul_Clamp
{
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
......@@ -17,26 +38,66 @@ struct Activation_Mul_Clamp
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
__device__ constexpr void operator()(int32_t& y, const int32_t& x) const
{
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
__host__ constexpr void operator()(float& y, const float& x) const
{
// CAUSION - We might float in & float out in reference code
activationOp_(y, x);
y = math::clamp(requantScale_ * y, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float y_fp32 = ck::type_convert<float>(x);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Activation_Mul2_Clamp
{
......@@ -51,13 +112,35 @@ struct Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__device__ constexpr void
operator()(int32_t& y, const int32_t& x, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activaiton(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
......@@ -72,6 +155,17 @@ struct Add_Activation_Mul_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
......@@ -92,15 +186,33 @@ struct Add_Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
......@@ -108,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32;
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy *Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale1_;
float requantScale2_;
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
Activation activationOp_;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThrough
{
template <typename Y, typename X>
......@@ -52,6 +57,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
......@@ -71,6 +82,36 @@ struct PassThrough
y = x;
}
#endif
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
{
y = type_convert<f8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
{
y = type_convert<f8_t>(x);
}
};
struct UnaryConvert
......@@ -82,6 +123,40 @@ struct UnaryConvert
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
......@@ -95,6 +170,12 @@ struct Scale
y = scale_ * x;
};
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
__host__ __device__ auto Value() const { return scale_; }
float scale_;
......@@ -196,36 +277,83 @@ struct Relu
}
};
// Y = FastGelu(X)
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function
struct FastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
y = x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
// device code, use lower precision "__expf" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
#if !CK_WORKAROUND_SWDEV_383542
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
#else
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
y = x * cdf;
}
template <>
__host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
y = type_convert<half_t>(y_f);
}
template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
}
template <>
__host__ void operator()<half_t, float>(half_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<half_t>(y_f);
}
template <>
__device__ void operator()<half_t, float>(half_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<half_t>(y_f);
}
};
......@@ -261,8 +389,36 @@ struct Sigmoid
y = 1 / (ck::type_convert<T>(1) + exp(-x));
};
};
int32_t divider_ = 1;
struct TanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
};
float beta_ = 1.0f;
};
} // namespace element_wise
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace ck {
template <typename GridwiseMultiblockBatchNormForward_,
typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_batchnorm_forward(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
y_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
mean_var_count_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
epsilon,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count,
p_control,
p_scale,
p_bias,
y_elementwise_op,
p_y,
updateMovingAverage,
averageFactor,
resultRunningMean,
resultRunningVariance,
saveMeanInvVariance,
resultSaveMean,
resultSaveInvVariance);
};
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseMultiblockBatchNormForward
{
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadwiseWelford1 =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using ThreadwiseWelford2 =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford1 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using BlockwiseWelford2 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
using ck::math::sqrt;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
if(block_local_id == 0)
gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> tmp_count_thread_buf;
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
// Step 1: each workgroup does local welford reduction
auto threadwise_welford_1 = ThreadwiseWelford1();
threadwise_welford_1.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
count_thread_buf(I) = threadwise_welford_1.cur_count_;
BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto mean_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto var_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto count_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_mean_var_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_count_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
if(thread_k_cluster_id == 0)
{
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_m_g,
mean_global_val_buf);
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
var_thread_buf,
mean_var_count_grid_desc_m_g,
var_global_val_buf);
threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
count_thread_buf,
mean_var_count_grid_desc_m_g,
count_global_val_buf);
};
gms_barrier(&p_control[blkgroup_id * 2]);
if(block_local_id == 0)
gms_reset(&p_control[blkgroup_id * 2]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
count_thread_buf(I) = 0;
});
constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
int32_t reducedSize = 0;
while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_count_thread_buf);
ThreadwiseWelford2::Run(tmp_mean_thread_buf,
tmp_var_thread_buf,
tmp_count_thread_buf,
mean_thread_buf,
var_thread_buf,
count_thread_buf);
reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 4: do normalization using the mean/variance
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
YElementwiseOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
y_elementwise_op);
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasSrcVectorSize,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias, bias_grid_desc_m.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y, y_grid_desc_m_k.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_bias_load.Run(bias_grid_desc_m,
bias_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
bias_thread_buf);
threadwise_x_load.SetSrcSliceOrigin(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
AccDataType fused_mean_bias =
bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
// normalize
y_thread_buf(Number<offset>{}) =
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
}
// Step 5: update the moving average of mean and variance (optional)
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_var_thread_buf;
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf);
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
mean_thread_buf[I] * averageFactor;
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
var_thread_buf[I] * averageFactor;
});
auto threadwise_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf,
mean_var_grid_desc_m,
running_mean_global_buf);
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf,
mean_var_grid_desc_m,
running_var_global_buf);
};
// Step 6: save mean and inv-variance (optional)
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
{
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
});
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
result_mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
var_thread_buf,
mean_var_grid_desc_m,
result_inv_var_global_buf);
};
}
}; // namespace ck
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
0,
1,
InMemoryDataOperationEnum::Set,
1,
......@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
0,
1,
InMemoryDataOperationEnum::Set,
1,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
......@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
p_in_welford_mean,
p_in_welford_variance,
......@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
......@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
0,
1,
1,
true>(
......@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
0,
1,
1,
true>(
......@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize);
int32_t reducedSize = 0;
while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf,
......@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf,
welford_count_thread_buf);
reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -109,30 +109,57 @@ struct BlockToCTileMap_M00_N0_M01
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt;
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
: BlockToCTileMap_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
const index_t grid_size = M0 * N0;
return M0 * N0;
}
return grid_size;
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
......@@ -140,8 +167,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
......@@ -154,6 +181,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
......@@ -165,11 +236,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M_;
index_t N_;
index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// keep the redundant type argument for backward compatibility
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
BlockToCTileMap_M00_N0_M01Adapt;
};
// 2D slices of column-vectors in 3D space
......@@ -543,4 +621,52 @@ struct OffsettedBlockToCTileMap
index_t block_start_;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.
*
* @paragraph Description
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
*
* @tparam MPerBlock Output block tile size in M dimension.
* @tparam NPerBlock Output block tile size in N dimension.
*/
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{
__host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
__host__ __device__ constexpr auto
CalculateGridSize(index_t M, index_t N, index_t k_split) const
{
// Create 3D grid
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return std::make_tuple(N0, M0, k_split);
}
template <typename TopIdx>
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
{
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : F[M, N0], where N0 is number of blocks along N dimension
// output : G[M, N0], where N0 is number of blocks along N dimension
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// F, G = welford(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension for E
template <typename ABDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType),
c_block_size * sizeof(CShuffleDataType));
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_pass_through_transform(NBlock)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
return grid_desc_mblock_mperblock_nblock;
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C, Welford and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>,
false>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_der_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
false>{};
// LDS c_shuffle_block_desc_mperblock_nperblock
constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
PostShuffleThreadClusterSize_M_N::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I1) ==
0,
"wrong!");
constexpr index_t PostShuffleThreadSliceSize_M =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I0);
constexpr index_t PostShuffleThreadSliceSize_N =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I1);
constexpr auto PostShuffleThreadSliceSize_M_N =
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
// VGPR post_shuffle_thread_desc_m_n
constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
// To apply D0, D1, ... and Welford.
// threadwise copy from LDS to VGPR
constexpr auto post_shuffle_thread_cluster_desc =
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
const auto post_shuffle_thread_cluster_idx =
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto post_shuffle_thread_data_idx_begin =
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
// To apply D0, D1, ... and Welford.
// Copy c shuffle from LDS back to VGPR
auto post_shuffle_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
AccDataType,
decltype(c_shuffle_block_desc_mperblock_nperblock),
decltype(post_shuffle_thread_desc_m_n),
decltype(PostShuffleThreadSliceSize_M_N),
Sequence<0, 1>,
1,
PostShuffleScalarPerVector,
1,
true>{c_shuffle_block_desc_mperblock_nperblock,
post_shuffle_thread_data_idx_begin};
// D0, D1, ..., Dn
constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<PostShuffleThreadSliceSize_M>{},
I1,
Number<PostShuffleThreadSliceSize_N>{}));
// FIXME: Decrease usage of VGPR
// Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
auto ds_thread_buf = generate_tuple(
[&](auto) {
return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
},
Number<NumDTensor>{});
// Copy D0, D1, ..., Dn from global to VGPR
auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
DDataType,
AccDataType,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>,
Sequence<0, 1, 2, 3>,
3,
PostShuffleScalarPerVector,
1,
true>(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
PostShuffleScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
true>{
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
// Welford
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
decltype(thread_welford_src_desc_m_k),
decltype(thread_welford_dst_desc_m)>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>,
false>;
constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
constexpr int num_shuffleN =
NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);
using mean_var_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block
if(block_work_idx[I1] % nblock == nblock - 1)
{
constexpr index_t NPerShuffleBlock =
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
int thread_max_len =
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
int shuffle_step = 0;
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
{
++shuffle_step;
thread_max_len += NPerShuffleBlock;
}
int delta = 0;
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
delta = 0;
else if(NPerBlockTail > thread_max_len)
delta = PostShuffleThreadSliceSize_N;
else
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
}
static_for<0, num_shuffleM, 1>{}([&](auto i) {
threadwise_welfords(i).max_count_ = max_count;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to read from LDS
block_sync_lds();
// each thread shuffle data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to write to LDS
block_sync_lds();
// Get shuffle data from LDS to VGPR
post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
post_shuffle_thread_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
// Global read D0, D1, ...
static_for<0, NumDTensor, 1>{}([&](auto Id) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
d_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
ds_grid_buf[Id],
post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
ds_thread_buf(Id));
if constexpr(access_id < num_access - 1)
{
// move on D0, D1, ...
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
}
});
// cde_element_op(e, c, d0, d1, ...);
static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
const auto c_ds_src_data_refs = concat_tuple_of_reference(
tie(e_thread_buf[i]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
Number<NumDTensor>{}));
auto e_dst_data_refs = tie(e_thread_buf(i));
unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
});
// Global write E
e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
if constexpr(access_id < num_access - 1)
{
// move on E
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
}
// Threadwise welford
auto& threadwise_welford = threadwise_welfords(shuffleM_index);
auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
auto& var_thread_buf = var_thread_bufs(shuffleM_index);
threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
constexpr int shuffleMInc =
de_global_step[I1] /
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
}
}); // copy c, d, e + welford
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
count_thread_buf(j) = threadwise_welfords(i).cur_count_;
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
if(post_shuffle_thread_cluster_idx[I1] == 0)
{
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
auto mean_var_count_thread_copy_index = make_multi_index(
block_work_idx[I0], // mblock
shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]); // nblock
auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>{mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf); // write mean
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); // write variance
// Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
// to be written.
if(i == 0 && block_work_idx[I0] == 0 &&
post_shuffle_thread_cluster_idx[I0] == 0)
{
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{count_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf); // write count
}
}
});
} // shuffle C + Ds + welford + write out
} // run
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
template <typename EMeanVarDataType,
typename HDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
index_t ESrcVectorSize,
index_t HDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
NThreadSliceSize % GammaSrcVectorSize == 0 &&
NThreadSliceSize % BetaSrcVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(NThreadSliceSize % HDstVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder = Sequence<0, 1>;
using ThreadClusterArrangeOrder = Sequence<0, 1>;
static constexpr auto thread_cluster_desc_m_n =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
static constexpr auto thread_buffer_desc_n =
make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));
using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
using ThreadWelfordDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_N,
ThreadClusterArrangeOrder>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
index_t NBlockClusterLength,
ComputeDataType epsilon,
HElementwiseOperation h_element_op)
{
// Thread/Block id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength,
block_global_id % NBlockClusterLength);
const auto thread_cluster_idx =
thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
// Global Memory
const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * NThreadSliceSize,
true>
h_thread_buf;
// IO
auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
mean_var_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_count_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
CountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
1,
1,
1,
true>(
count_grid_desc_m_nblock,
make_multi_index(block_work_idx[I0] * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
auto threadwise_e_load_m_n =
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
decltype(e_grid_desc_m_n),
decltype(thread_buffer_desc_m_n),
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
1, // SrcVectorDim
ESrcVectorSize,
1,
true>(
e_grid_desc_m_n,
make_multi_index(
block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));
auto threadwise_gamma_load_n =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
decltype(gamma_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_beta_load_n =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
ComputeDataType,
decltype(beta_grid_desc_n),
decltype(thread_buffer_desc_n),
ThreadBufferLengths_N,
Sequence<0>, // DimAccessOrder,
0, // SrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_n,
make_multi_index(block_work_idx[I1] * N_BlockTileSize +
thread_n_cluster_id * NThreadSliceSize));
auto threadwise_h_store_m_n =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
HDataType,
decltype(thread_buffer_desc_m_n),
decltype(h_grid_desc_m_n),
HElementwiseOperation,
ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder,
1, // DstVectorDim
HDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
h_grid_desc_m_n,
make_multi_index(
block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
h_element_op);
// step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_I0_n =
make_multi_index(I0, NThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
{
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
mean_var_count_thread_copy_step_I0_n);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
// step2: normalization
// h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
threadwise_e_load_m_n.Run(e_grid_desc_m_n,
e_global_val_buf,
thread_buffer_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon);
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) =
(e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
});
});
threadwise_gamma_load_n.Run(gamma_grid_desc_n,
gamma_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
});
});
threadwise_beta_load_n.Run(beta_grid_desc_n,
beta_global_val_buf,
thread_buffer_desc_n,
make_tuple(I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
});
});
threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
make_tuple(I0, I0),
h_thread_buf,
h_grid_desc_m_n,
h_global_val_buf);
} // run
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment