Unverified Commit 3f299c33 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Merge branch 'develop' into aosewski/ggemm_dl_instances

parents 507d793a 091570f5
...@@ -76,4 +76,8 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -76,4 +76,8 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perchannel_quantization_example.inc" #include "run_conv2d_fwd_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_perchannel_quantization_example(); } int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_perchannel_quantization_example(out_element_op);
}
...@@ -71,4 +71,9 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -71,4 +71,9 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perlayer_quantization_example.inc" #include "run_conv2d_fwd_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op);
}
...@@ -80,6 +80,10 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -80,6 +80,10 @@ using DeviceGroupedConvNDFwdInstance =
S<1, 64, 1, 4>, S<1, 64, 1, 4>,
8>; 8>;
#include "run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc" #include "run_conv2d_fwd_bias_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_bias_relu_perchannel_quantization_example(); }; int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op);
};
...@@ -78,6 +78,11 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -78,6 +78,11 @@ using DeviceGroupedConvNDFwdInstance =
S<1, 64, 1, 4>, S<1, 64, 1, 4>,
8>; 8>;
#include "run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc" #include "run_conv2d_fwd_bias_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_bias_relu_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op);
}
...@@ -80,4 +80,8 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -80,4 +80,8 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perchannel_quantization_example.inc" #include "run_conv2d_fwd_perchannel_quantization_example.inc"
int main() { run_conv2d_fwd_perchannel_quantization_example(); } int main()
{
const auto out_element_op = OutElementOp{ActivationOp{}};
run_conv2d_fwd_perchannel_quantization_example(out_element_op);
}
...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance =
#include "run_conv2d_fwd_perlayer_quantization_example.inc" #include "run_conv2d_fwd_perlayer_quantization_example.inc"
int main() { run_conv2d_fwd_perlayer_quantization_example(); } int main()
{
float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op);
}
...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -167,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perchannel_quantization_example() int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example() ...@@ -189,7 +189,6 @@ int run_conv2d_fwd_bias_relu_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_bias_relu_perlayer_quantization_example() int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example() ...@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perchannel_quantization_example() int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example() ...@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification,
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
int run_conv2d_fwd_perlayer_quantization_example() int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
{ {
bool do_verification = true; bool do_verification = true;
bool time_kernel = false; bool time_kernel = false;
...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example() ...@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example()
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC; using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC; using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
// FMA instruction // FMA instruction
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -7,10 +7,30 @@ namespace ck { ...@@ -7,10 +7,30 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenient
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul_Clamp struct Activation_Mul_Clamp
{ {
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp(float requantScale, Activation activationOp) Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp ...@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp
Activation activationOp_; Activation activationOp_;
}; };
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float y_fp32 = ck::type_convert<float>(x);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as // Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc // relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Activation_Mul2_Clamp struct Activation_Mul2_Clamp
{ {
...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp ...@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp
}; };
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Activation_Mul_Clamp struct Add_Activation_Mul_Clamp
{ {
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activaiton(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp ...@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp
}; };
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc // For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation> template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp struct Add_Mul_Activation_Mul_Clamp
{ {
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) // Convolution + Activation (non piecewise linear function)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) // Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{ {
} }
...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp ...@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{ {
float y_fp32 = ck::type_convert<float>(x + bias); float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32; y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy *Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float requantScale1_; __host__ __device__ constexpr void
float requantScale2_; operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
Activation activationOp_; Activation activationOp_;
}; };
......
...@@ -320,6 +320,19 @@ struct Sigmoid ...@@ -320,6 +320,19 @@ struct Sigmoid
int32_t divider_ = 1; int32_t divider_ = 1;
}; };
struct TanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
};
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_desc_k0perblock_nperblock_k1 = constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_space_size_aligned = math::integer_least_multiple( constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(ADataType) + constexpr auto c_block_space_size_aligned = math::integer_least_multiple(
b_block_space_size_aligned * sizeof(BDataType)); cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(),
max_lds_align);
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType)),
c_block_space_size_aligned * sizeof(CShuffleDataType));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if defined(__gfx90a__)
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
using ABDataTypeAdjusted = ABDataType;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
ABDataType, ABDataType,
ABDataType, ABDataTypeAdjusted,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1, BK1), math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataTypeAdjusted,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned, static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -166,15 +166,12 @@ __global__ void ...@@ -166,15 +166,12 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size = __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared,
a_b_k0_m_k1_grid_desc, a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc, b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{}; static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc), decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatABAdjusted,
decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc), decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk); math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared), static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment