Commit 686212eb authored by aska-0096

Merge branch 'lds_bypass_spilling' into lds_option_passthrough

parents 579f84c6 bdd0f64e
@@ -3,6 +3,8 @@
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
@@ -12,6 +14,46 @@ using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
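For reference, the last template argument is simply the intended store width in bits divided by the element width. A minimal compile-time check of that arithmetic, assuming F16 is the usual 2-byte half type and that 128 is meant as a 128-bit vectorized store:

#include <climits> // CHAR_BIT

// Sanity check only: 128-bit stores of 16-bit weights give 8 elements per vector.
static_assert(128 / (sizeof(WeiDataType) * CHAR_BIT) == 8,
              "expected 8 F16 elements per 128-bit store");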
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
                                                                                     InDataType,
@@ -54,7 +14,18 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
                                 const ck::utils::conv::ConvParam& conv_param)
{
    constexpr ck::index_t split_k = 2;
    ck::index_t split_k;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
{
split_k = 2;
}
else
{
split_k = 1;
}
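For context, split_k partitions the GEMM reduction dimension of the weight-gradient problem (roughly N * Ho * Wo for a 2D convolution) into split_k slices whose partial results are accumulated afterwards. A minimal sketch of that partitioning arithmetic, with a hypothetical helper name and assuming the reduction length divides evenly:

// Hypothetical illustration only: length of the reduction slice handled by each
// split-k partition. CK performs the actual split and the accumulation of the
// partial weight gradients inside the device operator.
inline ck::index_t reduction_length_per_split(ck::index_t gemm_k, ck::index_t split_k)
{
    // assumes gemm_k % split_k == 0
    return gemm_k / split_k;
}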
    const auto in_g_n_c_wis_desc =
        ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
@@ -144,7 +115,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
        std::cerr << "wrong! device_conv with the specified compilation parameters does "
                     "not support this Conv problem"
                  << std::endl;
        return false;
        return true;
    }
    float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
...
@@ -5,9 +5,6 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100")
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -17,8 +14,3 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
if(GPU_TARGETS MATCHES "gfx1100")
add_custom_target(example_gemm_scale_softmax_gemm_wmma)
add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
endif()
@@ -168,6 +168,9 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
#define CK_WORKAROUND_SWDEV_383542 1
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
...
@@ -65,20 +65,6 @@ struct BlockwiseGemmWMMA
    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
// FIX it, workaround
using ABlockDesc_temp = decltype(
make_naive_tensor_descriptor(make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
make_tuple(A_temp1* A_temp2* A_temp3* A_temp4,
A_temp2* A_temp3* A_temp4,
A_temp3* A_temp4,
A_temp4,
I1)));
    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
@@ -210,9 +196,6 @@ struct BlockwiseGemmWMMA
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
// constexpr auto NSubGroup =
// c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; constexpr auto MThreadPerSubGroup
// = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
        return make_naive_tensor_descriptor_packed(
@@ -302,7 +285,7 @@ struct BlockwiseGemmWMMA
    // Describe how data allocated in thread copy src buffer
    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...
@@ -111,6 +111,7 @@ __global__ void
    ignore = p_b_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = p_d0s_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c0de_element_op;
...
@@ -586,6 +586,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
            return false;
        }
if(ck::get_device_name() != "gfx90a" && std::is_same<ADataType, double>::value)
{
return false;
}
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                        arg.b_grid_desc_n_k_,
                                        arg.ds_grid_desc_m_n_,
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
@@ -280,43 +281,42 @@ struct AddHardswish
};
};
// removed:
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
    // Fast GeLU
    // https://paperswithcode.com/method/gelu
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;

    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
                      is_valid_param_type_v<D>);

        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));

        e = type_convert<E>(y);
    }

    template <typename D>
    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<D>);

        e = GetFastGeLU(c + type_convert<float>(d));
    }
};
// added:
// E = FastGelu(C + D)
struct AddFastGelu
{
    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

    template <>
    __host__ __device__ constexpr void
    operator()<float, float, float>(float& e, const float& c, const float& d) const
    {
        const float x = c + d;

        FastGelu{}.template operator()<float, float>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
    {
        const half_t x = c + d;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
    {
        const float x0_f = c + d;

        float x1_f = 0;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                         x0_f);

        e = type_convert<half_t>(x1_f);
    }
};
...
@@ -16,7 +16,7 @@ namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// silently do implicit type conversion
//
// Method 1:
// Example:
//
// struct ExampleElementwiseOp
// {
@@ -30,19 +30,6 @@
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct AddReluAdd
{
@@ -208,41 +195,74 @@ struct AddMultiply
    }
};
// removed:
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    // Fast GeLU
    // https://paperswithcode.com/method/gelu
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
        || std::is_same_v<T, ck::int4_t>
#endif
        ;

    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);

        const float y =
            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));

        e = type_convert<E>(y);
    }
};
// added:
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    template <>
    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
                                                                              const float& c,
                                                                              const float& d0,
                                                                              const float& d1) const
    {
        const float x = c + d0 + d1;

        FastGelu{}.template operator()<float, float>(e, x);
    }

    template <>
    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
    {
        const half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
    }

    template <>
    __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
        half_t& e, const float& c, const half_t& d0, const half_t& d1) const
    {
        const float x0_f = c + d0 + d1;

        float x1_f = 0;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                         x0_f);

        e = type_convert<half_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
        bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
    {
        const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);

        float x1_f = 0;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                         x0_f);

        e = type_convert<bhalf_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
        int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
    {
        const float x0_f =
            type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);

        float x1_f = 0;

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
                                                                                         x0_f);

        e = type_convert<int8_t>(x1_f);
    }
};
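A minimal usage sketch for the functor above, assuming it lives in ck::tensor_operation::element_wise as the fully qualified FastGelu calls suggest; it can be applied directly to plain scalars:

// Hypothetical snippet, for illustration only.
void add_add_fast_gelu_scalar_demo()
{
    float e = 0.f;
    ck::tensor_operation::element_wise::AddAddFastGelu{}(e, 1.0f, 0.5f, -0.25f);
    // e now holds the fast-GeLU approximation evaluated at 1.0f + 0.5f + (-0.25f).
}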
...
@@ -11,6 +11,10 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThrough
{
    template <typename Y, typename X>
@@ -200,36 +204,83 @@ struct Relu
    }
};
// removed:
// Y = FastGelu(X)
struct FastGelu
{
    // Fast GeLU
    // https://paperswithcode.com/method/gelu
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
        || std::is_same_v<T, ck::int4_t>
#endif
        ;

    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);

        const float tmp_y = GetFastGeLU(type_convert<float>(x));
        y                 = type_convert<Y>(tmp_y);
    }
};
// added:
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function
struct FastGelu
{
    template <typename Y, typename X>
    __host__ void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ void operator()<float, float>(float& y, const float& x) const
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);

        y = x * cdf;
    }

    // device code, use lower precision "__expf" and "rcp"
    template <>
    __device__ void operator()<float, float>(float& y, const float& x) const
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = __expf(-u);
#if !CK_WORKAROUND_SWDEV_383542
        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
#else
        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
#endif

        y = x * cdf;
    }

    template <>
    __host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<half_t>(y_f);
    }

    template <>
    __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<half_t>(y_f);
    }

    template <>
    __host__ void operator()<half_t, float>(half_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<half_t>(y_f);
    }

    template <>
    __device__ void operator()<half_t, float>(half_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<half_t>(y_f);
    }
};
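The polynomial above is the standard tanh-based GeLU approximation rewritten through a sigmoid, which is why a single exp (or __expf) plus one reciprocal suffices; for reference:

\[
\mathrm{GeLU}(x) \approx \tfrac{1}{2}\,x\left(1+\tanh\!\left(\sqrt{2/\pi}\,\bigl(x+0.044715\,x^{3}\bigr)\right)\right),
\qquad
\tfrac{1}{2}\Bigl(1+\tanh\tfrac{u}{2}\Bigr)=\frac{1}{1+e^{-u}} .
\]
\[
\text{With } u = 2\sqrt{2/\pi}\,\bigl(x+0.044715\,x^{3}\bigr) \approx 2x\,\bigl(0.797885+0.035677\,x^{2}\bigr),
\quad
\mathrm{cdf} = \tfrac{1}{2}+\tfrac{1}{2}\!\left(\frac{2}{1+e^{-u}}-1\right)=\frac{1}{1+e^{-u}} .
\]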
...
@@ -247,20 +247,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
            constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Err: merge transform cause non-constexpr issue
// return transform_tensor_descriptor(
// ABlockDesc_{},
// make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
// make_pass_through_transform(Number<MRepeat>{}),
// make_pass_through_transform(I1),
// make_pass_through_transform(I1),
// make_pass_through_transform(Number<A_K1>{})),
// make_tuple(Sequence<0, 3>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
// Sequence<4>{}));
// Workaround, Freeze transform
            // removed:
            return transform_tensor_descriptor(
                ABlockDesc_{},
                make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
                           make_pass_through_transform(Number<MRepeat>{}),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(Number<A_K1>{})),
                make_tuple(Sequence<0, 3>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            // added:
            return transform_tensor_descriptor(
                ABlockDesc_{},
                make_tuple(make_freeze_transform(I0),
                           make_pass_through_transform(Number<KWmma>{}),
                           make_pass_through_transform(Number<MRepeat>{}),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(Number<A_K1>{})),
                make_tuple(Sequence<3>{},
                           Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<>{},
                           Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{}));
        }
    }();
@@ -456,14 +481,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    static constexpr auto a_block_space_size_aligned =
        AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
                                                  max_lds_align) *
                         sizeof(FloatA)
                   : 0;
    static constexpr auto a_block_space_size_aligned =
        AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
                                                  max_lds_align)
                   : 0;
    static constexpr auto b_block_space_size_aligned =
        BEnableLds ? math::integer_least_multiple(
                         GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
                         max_lds_align) *
                         sizeof(FloatB)
                   : 0;
    static constexpr auto b_block_space_size_aligned =
        BEnableLds ? math::integer_least_multiple(
                         GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
                         max_lds_align)
                   : 0;
    static constexpr auto a_block_space_offset = 0;
@@ -472,13 +495,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_space_size =
        GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
            .GetElementSpaceSize() *
        sizeof(FloatCShuffle);
    static constexpr auto c_shuffle_block_space_size =
        GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
            .GetElementSpaceSize();
    static constexpr auto c_shuffle_block_space_offset = 0;
    static constexpr auto lds_size = math::max(
        c_shuffle_block_space_size, (a_block_space_size_aligned + b_block_space_size_aligned));
    static constexpr auto lds_size =
        math::max(c_shuffle_block_space_size * sizeof(FloatCShuffle),
                  a_block_space_size_aligned * sizeof(FloatA) +
                      b_block_space_size_aligned * sizeof(FloatB));
};
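The reworked SharedMemTrait keeps the A, B, and C-shuffle sizes in elements and converts to bytes only in lds_size, taking the max because the C-shuffle stage reuses the LDS region that held the A and B tiles. A small worked example with assumed values (not taken from this kernel):

// Illustration only, assuming FloatA = FloatB = half (2 bytes), FloatCShuffle = float
// (4 bytes), 4096 A elements, 4096 B elements and 2048 C-shuffle elements.
constexpr unsigned a_bytes         = 4096 * 2; // 8192
constexpr unsigned b_bytes         = 4096 * 2; // 8192
constexpr unsigned c_shuffle_bytes = 2048 * 4; // 8192
// LDS requirement = max(A tile + B tile, C-shuffle buffer) = 16384 bytes here.
constexpr unsigned lds_bytes =
    (a_bytes + b_bytes > c_shuffle_bytes) ? (a_bytes + b_bytes) : c_shuffle_bytes;
static_assert(lds_bytes == 16384, "max(8192 + 8192, 8192) == 16384");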
    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
@@ -539,7 +563,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            constexpr auto K0PerBlock = KPerBlock/ K1;
            auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatA*>(p_shared),
                a_block_desc.GetElementSpaceSize());
            auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatA*>(p_shared),
                SharedMemTrait::a_block_space_size_aligned);
            auto a_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -614,8 +638,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        {
            constexpr auto K0PerBlock = KPerBlock/ K1;
            auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
                b_block_desc.GetElementSpaceSize());
            auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatB*>(p_shared) + SharedMemTrait::b_block_space_offset,
                SharedMemTrait::b_block_space_size_aligned);
            auto b_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -726,13 +750,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
#if 0
static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
printf("tid: %03d, c_thread_buf[%02d] val: %08x\n", get_thread_local_1d_id(), i.value,
*(reinterpret_cast<const uint32_t*>(&(c_thread_buf[i]))));
// c_thread_buf(i) = 32;
});
#endif
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
@@ -751,7 +768,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared), SharedMemTrait::c_shuffle_block_space_size);
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
                SharedMemTrait::c_shuffle_block_space_size);
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
...
@@ -10,8 +10,8 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=OFF \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx90a" \
-D GPU_TARGETS="gfx908;gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
...