Commit 41b920e2 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 874a78f9 5d718e6b
...
@@ -531,8 +531,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         namespace ctc = tensor_layout::convolution;

         // check device
-        if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
...
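Note on the hunk above: the per-device string comparisons are replaced by `ck::is_navi3_supported()`. A minimal sketch of what such a host-side predicate presumably looks like (an illustrative assumption; the actual helper ships in CK's host utilities and may differ):

// Illustrative sketch only; assumes the ck::get_device_name() helper used above.
#include <string>

namespace ck {
inline bool is_navi3_supported()
{
    const std::string device = get_device_name();
    return device == "gfx1100" || device == "gfx1101" || device == "gfx1102";
}
} // namespace ck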
...
@@ -39,9 +39,8 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
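Throughout this commit the long per-target guard lists collapse into family macros (`__gfx103__`, `__gfx11__`, `__gfx94__`). These are not compiler builtins; presumably a shared CK header derives them from the per-target builtins, roughly along these lines (our assumption, not shown in this diff):

// Hypothetical umbrella-macro definitions (sketch): map the compiler's
// per-target builtins onto one macro per GPU family.
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#define __gfx11__
#endif
#if defined(__gfx1030__) // possibly further gfx103x targets
#define __gfx103__
#endif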
...
@@ -668,26 +667,24 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
             return false;
         }

-        const std::string device_name = ck::get_device_name();
-
-        // TODO add newer Navi arch
-        if(device_name != "gfx906" and device_name != "gfx908" and device_name != "gfx90a" and
-           device_name != "gfx1030")
-        {
-            return false;
-        }
-
-        for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_navi2_supported() || ck::is_navi3_supported())
         {
-            if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
-                                            arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
-                                            arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
+            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
-                return false;
+                if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
+                                                arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
+                                                arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
+                {
+                    return false;
+                }
             }
+            return true;
+        }
+        else
+        {
+            return false;
         }
-
-        return true;
     }

     // polymorphic
...
...
@@ -44,7 +44,7 @@ __global__ void
             const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
...
...
@@ -39,7 +39,7 @@ __global__ void
             const CDEElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -50,7 +50,7 @@ __global__ void
             const CDEElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
...
@@ -650,32 +650,52 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
         constexpr auto Set = InMemoryDataOperationEnum::Set;

-        if(arg.k_batch_ > 1)
+        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
+        // in IsSupportedArgument function
+        if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
         {
             if(has_main_k_block_loop)
             {
-                ave_time =
-                    launch_kernel(integral_constant<bool, true>{},
-                                  integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<InMemoryDataOperationEnum, Set>{});
             }
             else
             {
-                ave_time =
-                    launch_kernel(integral_constant<bool, false>{},
-                                  integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<InMemoryDataOperationEnum, Set>{});
             }
         }
         else
         {
-            if(has_main_k_block_loop)
+            if(arg.k_batch_ > 1)
             {
-                ave_time = launch_kernel(integral_constant<bool, true>{},
-                                         integral_constant<InMemoryDataOperationEnum, Set>{});
+                if(has_main_k_block_loop)
+                {
+                    ave_time = launch_kernel(
+                        integral_constant<bool, true>{},
+                        integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(
+                        integral_constant<bool, false>{},
+                        integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
+                }
             }
             else
             {
-                ave_time = launch_kernel(integral_constant<bool, false>{},
-                                         integral_constant<InMemoryDataOperationEnum, Set>{});
+                if(has_main_k_block_loop)
+                {
+                    ave_time =
+                        launch_kernel(integral_constant<bool, true>{},
+                                      integral_constant<InMemoryDataOperationEnum, Set>{});
+                }
+                else
+                {
+                    ave_time =
+                        launch_kernel(integral_constant<bool, false>{},
+                                      integral_constant<InMemoryDataOperationEnum, Set>{});
+                }
             }
         }
...
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             }
         }

+        // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
+        // instruction that supports bf16 and we cannot use splitk because of that
+        if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
+        {
+            supported = supported & (arg.k_batch_ == 1);
+        }
+
         return supported;
     }
...
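Context for the bf16 restriction above: with split-K (`k_batch_ > 1`) every k-batch accumulates a partial product into the same output tile, so the epilogue store must be an atomic add; without split-K a plain store (`Set`) suffices. A HIP-style sketch of the idea (ours, not CK's launcher code):

// Illustrative only: float has a hardware atomicAdd on these targets,
// bhalf_t does not, which is why bf16 is limited to k_batch == 1.
__device__ void store_partial(float* e, int idx, float partial, int k_batch)
{
    if(k_batch > 1)
        atomicAdd(e + idx, partial); // InMemoryDataOperationEnum::AtomicAdd
    else
        e[idx] = partial;            // InMemoryDataOperationEnum::Set
}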
...
@@ -35,7 +35,7 @@ __global__ void
             const index_t group_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
...
...
@@ -57,7 +57,7 @@ __global__ void
             const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -75,6 +75,15 @@ struct Add
         y = ck::type_convert<bhalf_t>(y_tmp);
     }

+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float x2_tmp = ck::type_convert<float>(x1);
+        const float y_tmp = x0 + x2_tmp;
+        y = ck::type_convert<bhalf_t>(y_tmp);
+    }
+
     template <>
     __host__ __device__ constexpr void
     operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
...
@@ -156,7 +165,7 @@ struct Subtract
 struct Bilinear
 {
-    Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
+    Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};

     template <typename Y, typename X0, typename X1>
     __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
...
@@ -175,6 +184,14 @@ struct Bilinear
         y = alpha_ * x0 + beta_ * x1;
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
+    {
+        y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
+                                 beta_ * type_convert<float>(x1));
+    };
+
     template <>
     __host__ __device__ constexpr void
     operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...
@@ -212,7 +229,8 @@ struct Bilinear
     __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
         std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
     {
-        y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
+        y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
+                                 beta_ * type_convert<float>(x1));
     };

     float alpha_;
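The int8/int32 specialization previously ignored alpha_ and beta_ and just added the operands; it now applies both scales like the other paths. A quick sanity check with hypothetical values (host side):

ck::tensor_operation::element_wise::Bilinear op{0.5f, 2.f};
std::int8_t y = 0;
op(y, std::int32_t{10}, std::int8_t{3}); // y = 0.5 * 10 + 2 * 3 = 11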
...
@@ -264,6 +282,14 @@ struct AddRelu
         y = a > 0.0f ? a : 0.0f;
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float a = x0 + type_convert<float>(x1);
+        y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
+    };
+
     template <>
     __host__ __device__ constexpr void
     operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
...
@@ -354,6 +380,70 @@ struct AddFastGelu
         e = type_convert<half_t>(x1_f);
     }

+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
+    {
+        const float x0_f = c + type_convert<float>(d);
+        float x1_f = 0;
+        FastGelu{}.template operator()<float, float>(x1_f, x0_f);
+        e = type_convert<bhalf_t>(x1_f);
+    }
+};
+
+// E = Silu(C + D)
+struct AddSilu
+{
+    template <typename E, typename C, typename D>
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(float& e, const float& c, const float& d) const
+    {
+        const float x = c + d;
+        Silu{}.template operator()<float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
+    {
+        const half_t x = c + d;
+        Silu{}.template operator()<half_t>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
+    {
+        const float x0_f = c + d;
+        float x1_f = 0;
+        Silu{}.template operator()<float>(x1_f, x0_f);
+        e = type_convert<half_t>(x1_f);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
+    {
+        const float x0_f = c + type_convert<float>(d);
+        float x1_f = 0;
+        Silu{}.template operator()<float>(x1_f, x0_f);
+        e = type_convert<bhalf_t>(x1_f);
+    }
 };

 } // namespace element_wise
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -21,50 +21,11 @@ struct PassThroughPack2
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

-    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
-    {
-        // fake conversion
-        uint16_t t = ck::bit_cast<uint32_t>(x);
-        y = ck::bit_cast<ck::f8x2_t>(t);
-    }
-
     __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
         y = type_convert<half2_t>(t);
     }
-
-    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
-    {
-        y = x;
-    }
-
-    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
-    {
-        y = x;
-    }
-
-    __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
-    {
-        y = x;
-    }
-
-    __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
-    {
-        y = x;
-    }
-
-    __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
-    {
-        y = x;
-    }
-
-    __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
-    {
-        y = x;
-    }
-
-    constexpr const static bool is_pack2_invocable = true;
 };

 struct PassThrough
...
@@ -156,6 +117,12 @@ struct PassThrough
         y = type_convert<half_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<bhalf_t, int8_t>(bhalf_t& y, const int8_t& x) const
+    {
+        y = type_convert<bhalf_t>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
     {
...
@@ -452,27 +419,29 @@ struct FastGelu
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-
-        y = x * cdf;
+        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
+        const float c1 = -2.0 * 0.035677f;
+        const float c2 = -2.0 * 0.797885f;
+        const float u = x * (c1 * x * x + c2);
+        const float emu = exp(u);
+        y = x / (1.f + emu);
     }

     // device code, use lower precision "__expf" and "rcp"
     template <>
     __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
+        // const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float c1 = -2.0 * 0.035677f;
+        const float c2 = -2.0 * 0.797885f;
+        const float u = x * (c1 * x * x + c2);
+        const float emu = __expf(u);
 #if !CK_WORKAROUND_SWDEV_383542
-        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
+        y = x * __frcp_rn(1.f + emu);
 #else
-        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
+        y = x * __ocml_native_recip_f32(1.f + emu);
 #endif
-
-        y = x * cdf;
     }
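The FastGelu change above is an algebraic simplification, not a behavior change. A short derivation (ours, not part of the commit), using the identity 2/(1+e^{-u}) - 1 = tanh(u/2):

\[
u = 2x\,(0.035677\,x^2 + 0.797885), \qquad
y = x\left(\tfrac{1}{2} + \tfrac{1}{2}\tanh\tfrac{u}{2}\right)
  = \frac{x}{1 + e^{-u}}
\]

Folding the sign and the factor of 2 into the constants (c1 = -2 * 0.035677, c2 = -2 * 0.797885) lets the new code compute e^{-u} directly and drop the cdf intermediate, saving a few multiply-adds per element.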
...
@@ -551,6 +520,19 @@ struct Sigmoid
     };
 };

+struct Silu
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
+                          is_same_v<T, int8_t> || is_same_v<T, int32_t>,
+                      "Data type is not supported by this operation!");
+        constexpr T one = type_convert<T>(1);
+        y = x * (one / (one + ck::math::exp(-x)));
+    };
+};
+
 struct TanH
 {
     template <typename T>
...
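The new Silu functor computes y = x * sigmoid(x) = x / (1 + e^{-x}), which the AddSilu operator above composes with an addition. A hypothetical host-side usage:

ck::tensor_operation::element_wise::Silu silu{};
float y = 0.f;
silu(y, 2.0f); // y = 2 / (1 + exp(-2)) ≈ 1.76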
...
@@ -67,7 +67,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
...
@@ -28,8 +28,7 @@ __global__ void
 #endif
        kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
-    defined(__gfx1101__) || defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
...
...
@@ -54,8 +54,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map,
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -148,8 +147,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
             const Block2CTileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // printf("entry kernel launch");
     __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
...
@@ -244,8 +242,7 @@ __global__ void
             const CDEElementwiseOperation cde_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

     GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -274,7 +271,7 @@ __global__ void
     ignore = b_element_op;
     ignore = cde_element_op;
     ignore = block_2_ctile_map;
-#endif // end of if (defined(__gfx1100__ ))
+#endif // end of if (defined(__gfx11__ ))
 }

 template < // DataType Family
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include "ck/utility/amd_lds.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -55,8 +56,7 @@ __global__ void
                 e_grid_desc_mblock_mperblock_nblock_nperblock,
             const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
-    defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -492,22 +492,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }

-    template <typename DataType>
-    __device__ static auto AllocateBlockBuffers(void* p_shared,
-                                                int32_t num_elems,
-                                                int32_t offset_elems,
-                                                int32_t max_lds_align)
-    {
-        const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align);
-        return generate_tuple(
-            [&](auto i) {
-                const int32_t local_offset = i * single_buffer_offset;
-                return make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<DataType*>(p_shared) + local_offset + offset_elems, num_elems);
-            },
-            Number<NumGemmKPrefetchStage>{});
-    }
-
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
...
@@ -641,14 +625,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

-        auto a_block_buffers = AllocateBlockBuffers<AComputeDataType>(
-            p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
+        const auto a_buffers_offset = 0;
+        auto a_block_buffers =
+            ck::lds_utils::AllocateLdsBuffers<AComputeDataType, NumGemmKPrefetchStage>(
+                p_shared,
+                a_block_desc_ak0_m_ak1.GetElementSpaceSize(),
+                a_buffers_offset,
+                max_lds_align);

         const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
         auto b_block_buffers =
-            AllocateBlockBuffers<BComputeDataType>(p_shared,
-                                                   b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
-                                                   b_buffers_offset,
-                                                   max_lds_align);
+            ck::lds_utils::AllocateLdsBuffers<BComputeDataType, NumGemmKPrefetchStage>(
+                p_shared,
+                b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
+                b_buffers_offset,
+                max_lds_align);

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
...
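The refactor above keeps the same LDS layout and only moves the allocation helper into the shared ck::lds_utils header: per input, one aligned chunk per prefetch stage laid out back to back, with B's buffers starting after all of A's stages. A sketch of the offset arithmetic (ours, mirroring the removed AllocateBlockBuffers):

// Offset of prefetch stage `stage` within one buffer family, in elements
// (illustrative; "aligned" corresponds to math::integer_least_multiple).
int stage_offset_elems(int num_elems, int max_lds_align, int stage)
{
    const int aligned = ((num_elems + max_lds_align - 1) / max_lds_align) * max_lds_align;
    return stage * aligned;
}
// B's family then starts at a_block_space_size_aligned * NumGemmKPrefetchStage.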
...
@@ -55,7 +55,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
...
@@ -49,8 +49,7 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -75,7 +74,7 @@ __global__ void
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = block_2_ctile_map;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 template <index_t BlockSize,
...
...
@@ -25,7 +25,7 @@ __global__ void
        kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
@@ -50,7 +50,7 @@ __global__ void
             typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
...
...
@@ -26,7 +26,7 @@ __global__ void
        kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
@@ -54,7 +54,7 @@ __global__ void
             typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
...
@@ -58,7 +58,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     // TODO ANT: separate into MMA + Epilogue
...
...
@@ -167,7 +167,7 @@ __global__ void
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
...
@@ -45,7 +45,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
...