"docs/source/ko/index.mdx" did not exist on "02d83c9ff1b93f2c6f9c94f9369b3e4bc1ba8ce7"
Commit dbb7002d authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/hotloop

parents 96c8d948 2bef5501
@@ -60,8 +60,7 @@ __global__ void
        const Block2CTileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -103,7 +102,7 @@ __global__ void
     compute_ptr_offset_of_batch.GetAPtrOffset(0);
     compute_ptr_offset_of_batch.GetBPtrOffset(0);
     compute_ptr_offset_of_batch.GetCPtrOffset(0);
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <index_t NDimSpatial,
......
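The recurring change across these kernel files replaces the per-target guard list (`__gfx908__`, `__gfx90a__`, `__gfx94__`) with a single `__gfx9__` family macro. A hypothetical sketch of the consolidation this relies on; CK's real definitions (including the `__gfx94__` umbrella the old guards used) live in a common header not shown in this diff:

```cpp
// Sketch only: fold the compiler's per-target macros into one family macro,
// so every kernel guard needs to test just __gfx9__. The exact target list
// CK folds in is an assumption here.
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || \
    defined(__gfx941__) || defined(__gfx942__)
#define __gfx9__
#endif
```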
@@ -55,8 +55,7 @@ __global__ void
        [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
        [[maybe_unused]] const index_t num_k_per_block)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
@@ -85,7 +84,7 @@ __global__ void
        k_idx);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename GridwiseGemm,
@@ -145,7 +144,7 @@ __global__ void
        k_idx);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <ck::index_t NDimSpatial,
......
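Note the pattern in the hunks above for targets where the kernel body is compiled out: parameters are either marked `[[maybe_unused]]` or consumed via `ignore` in the `#else` branch, so warnings-as-errors builds stay clean. A minimal standalone illustration of the same idiom (kernel and argument names are invented for the example):

```cpp
#include "ck/utility/ignore.hpp" // provides ck::ignore

template <typename Kernel, typename KernelArgs>
__global__ void demo_kernel(const KernelArgs karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    Kernel::Run(karg); // real body, only built for supported gfx9 targets
#else
    ignore = karg; // consume the argument; the kernel is a stub elsewhere
#endif
}
```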
@@ -9,6 +9,7 @@
 #include <numeric>
 #include <sstream>

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -98,8 +99,7 @@ __global__ void
        const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
        const ComputePtrOffsetOfN compute_ptr_offset_of_n)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
@@ -212,9 +212,13 @@ __global__ void
 }

 } // namespace

+#ifdef CK_CODE_GEN_RTC
+template <typename T>
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#else
 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());
+#endif

 //
 // @brief Device Convolution operation.
......
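`is_tuple` is a detection alias: it is well-formed only for types that expose an `IsTuple()` member, and the new `CK_CODE_GEN_RTC` branch merely swaps `std::declval` for `ck::declval`, since `<utility>` is unavailable when compiling under hipRTC. A self-contained sketch of how such an alias drives the detection idiom (the hand-rolled `is_detected` machinery below is illustrative, not CK's):

```cpp
#include <type_traits>
#include <utility>

// Detection machinery (void_t idiom, written out by hand for the demo).
template <typename, template <typename> class, typename = void>
struct is_detected : std::false_type {};

template <typename T, template <typename> class Op>
struct is_detected<T, Op, std::void_t<Op<T>>> : std::true_type {};

// The alias from the diff: well-formed only if T has IsTuple().
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

struct NotATuple {};
struct TupleLike { bool IsTuple() { return true; } };

static_assert(!is_detected<NotATuple, is_tuple>::value);
static_assert(is_detected<TupleLike, is_tuple>::value);
```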
@@ -9,6 +9,7 @@
 #include <numeric>
 #include <sstream>

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -117,7 +118,7 @@ __global__ void
        c_grid_desc_mblock_mperblock_nblock_nperblock);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename GridwiseGemm,
@@ -183,7 +184,7 @@ __global__ void
        c_grid_desc_mblock_mperblock_nblock_nperblock);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 } // namespace
......
@@ -155,8 +155,7 @@ __global__ void
        const Block2ETileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
@@ -52,8 +52,7 @@ __global__ void
        const ComputePtrOffset compute_ptr_offset_of_groups,
        const ComputePtrOffset compute_ptr_offset_of_n)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
......
@@ -68,8 +68,7 @@ __global__ void
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
@@ -404,7 +403,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
......
@@ -43,8 +43,7 @@ __global__ void
        const B1ElementwiseOperation b1_element_op,
        const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
@@ -109,7 +108,7 @@ __global__ void
     ignore = acc_element_op;
     ignore = b1_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 // Computes C = A * B0 * B1
......
@@ -38,8 +38,7 @@ __global__ void
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
......
@@ -50,8 +50,7 @@ __global__ void
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
......
@@ -40,8 +40,7 @@ __global__ void
        const BElementwiseOperation b_element_op,
        const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
@@ -80,7 +79,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
......
@@ -3,6 +3,7 @@

 #pragma once

+#include "ck/library/utility/numeric.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
......
@@ -56,8 +56,7 @@ __global__ void
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
        const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
......
@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout

 } // namespace convolution

+#ifndef CK_CODE_GEN_RTC
 template <
     typename Layout,
     typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
     os << Layout::name;
     return os;
 }
+#endif

 } // namespace tensor_layout
 } // namespace ck
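Guards like this exist because hipRTC compiles source without the host standard library, so anything touching `<ostream>` or other `std::` facilities has to be fenced off (or replaced with `ck::` equivalents, as in other hunks of this commit). A hedged sketch of the same pattern in isolation (the `CK_CODE_GEN_RTC` macro name comes from the diff; the surrounding code is illustrative):

```cpp
#ifndef CK_CODE_GEN_RTC
#include <ostream> // host-only: hipRTC builds have no iostreams
#endif

namespace ck_example { // illustrative namespace, not CK's

struct RowMajor { static constexpr const char* name = "RowMajor"; };

#ifndef CK_CODE_GEN_RTC
// Printing support only matters for host compilation; device/RTC builds
// never stream layouts, so the operator is simply compiled away there.
template <typename Layout>
std::ostream& operator<<(std::ostream& os, const Layout&)
{
    return os << Layout::name;
}
#endif

} // namespace ck_example
```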
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -340,8 +340,8 @@ struct Bilinear
     };

     template <>
-    __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
-        std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
+    __host__ __device__ constexpr void
+    operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
     {
         y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
                                  beta_ * type_convert<float>(x1));
......
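The `int8_t` specialization keeps the arithmetic in `float` and only converts at the boundaries, avoiding `int32_t` overflow and double rounding. A small host-side sketch of what the functor computes (the free-function form below is an illustration, not CK's interface):

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for Bilinear's int8 path: y = convert<int8>(alpha*x0 + beta*x1).
int8_t bilinear_i8(float alpha, float beta, int32_t x0, int8_t x1)
{
    float acc = alpha * static_cast<float>(x0) + beta * static_cast<float>(x1);
    return static_cast<int8_t>(acc); // ck::type_convert applies CK's own rounding rules
}

int main()
{
    // GEMM accumulator x0 = 40, residual x1 = 2, alpha = 0.5, beta = 1.5:
    assert(bilinear_i8(0.5f, 1.5f, 40, 2) == 23); // 0.5*40 + 1.5*2 = 23
    return 0;
}
```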
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -533,7 +533,7 @@ struct NormalizeInInfer
                const T3& gamma,
                const T4& beta) const
     {
-        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
+        static_assert(is_same<T2, float>::value || is_same<T2, double>::value,
                       "Data type is not supported by this operation!");

         using ck::type_convert;
......
@@ -16,7 +16,8 @@ namespace ck {
 // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
 // (https://arxiv.org/abs/2211.10017) and implementation:
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__host__ __device__ inline half4_t pki4_to_half4(int q)
+// Convert lower part of packed int4 -> int4 to half
+__device__ inline half4_t i4_to_half4(int q)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;
@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
+__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;
@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
-{
-#if 1
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
-
-    const int EX  = 0x64006400;
-    const int SUB = 0xE408E408; //-8
-
-    int lo = i4s | EX;
-
-    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
-#else
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    vector_type<half_t, 2> res;
-
-    half_t x_h = (x_u8 & 0x0f) - 8;
-    half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
-
-    res.template AsType<half_t>()(Number<0>{}) = x_l;
-    res.template AsType<half_t>()(Number<1>{}) = x_h;
-
-    return res.template AsType<half2_t>()[Number<0>{}];
-#endif
-}
-
-__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
+__device__ inline bhalf4_t i4_to_bhalf4(int q)
 {
     uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
     return res.template AsType<bhalf4_t>()[Number<0>{}];
 }

-__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
-{
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-
-    float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
-    float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
-
-    vector_type<bhalf_t, 2> res;
-
-    res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
-    res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
-
-    return res.template AsType<bhalf2_t>()[Number<0>{}];
-}
-
 namespace tensor_operation {
 namespace element_wise {
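The deleted `pki4_to_half2` shows the fp16 bias trick these converters share: OR-ing a 4-bit nibble into `0x6400` (fp16 for 1024.0) encodes `1024 + n` exactly, because the fp16 ulp in [1024, 2048) is 1; a packed add of `0xE408` (fp16 for -1032.0) then shifts the unsigned nibble range [0, 15] to the signed int4 range [-8, 7]. A standalone check of that arithmetic (plain C++, minimal fp16 decode, no CK types):

```cpp
#include <cassert>
#include <cstdint>

// Decode an IEEE fp16 bit pattern to float. Only valid for normal numbers
// with biased exponent >= 15, which covers the constants in this demo.
float half_bits_to_float(uint16_t h)
{
    int sign = (h >> 15) ? -1 : 1;
    int exp  = (h >> 10) & 0x1f; // biased exponent
    int man  = h & 0x3ff;        // 10 mantissa bits
    return sign * (1.0f + man / 1024.0f) * float(1 << (exp - 15));
}

int main()
{
    // 0x6400 is fp16 1024.0; for n in [0,15], 0x6400 | n encodes exactly 1024+n.
    for(uint16_t n = 0; n < 16; ++n)
        assert(half_bits_to_float(0x6400 | n) == 1024.0f + n);

    // 0xE408 is fp16 -1032.0, so (0x6400 | n) + (-1032) == n - 8: the signed int4 value.
    assert(half_bits_to_float(0xE408) == -1032.0f);
    for(int n = 0; n < 16; ++n)
        assert(half_bits_to_float(0x6400 | n) + half_bits_to_float(0xE408) == float(n - 8));
    return 0;
}
```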
@@ -159,11 +118,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
-        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else
@@ -171,13 +130,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<half2_t>()(Number<1>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<half2_t>()(Number<2>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<half2_t>()(Number<3>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif
@@ -185,11 +144,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<bhalf_t, 8> result;

-        result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
-        result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
+        result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
+        result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);

         y = result.template AsType<bhalf8_t>()[Number<0>{}];
 #else
@@ -197,13 +156,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<bhalf2_t>()(Number<0>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<bhalf2_t>()(Number<1>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<bhalf2_t>()(Number<2>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<bhalf2_t>()(Number<3>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<bhalf8_t>()[Number<0>{}];
 #endif
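The `#if 1` → `#if CK_USE_PK4_LAYOUT_SHUFFLE` change makes the choice between the two dequantization paths explicit: the fast path assumes the packed int4 data was pre-shuffled so that `i4_to_half4` can convert four values per magic-number pass, while the fallback unpacks each `pk_i4_t` byte in storage order. A scalar sketch of the fallback's nibble ordering, following the deleted `pki4_to_bhalf2` (the part that is easy to get backwards; plain C++, no CK types):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Unpack one pk_i4 byte into its two signed int4 values: element 0 comes
// from the HIGH nibble, element 1 from the LOW nibble, matching the deleted
// fallback. The layout shuffle elsewhere rearranges bytes so the magic-number
// path can instead convert four low nibbles at once.
std::pair<int, int> unpack_pk_i4(uint8_t byte)
{
    int elem0 = ((byte & 0xf0) >> 4) - 8; // "x_l" in the diff: written to lane 0
    int elem1 = (byte & 0x0f) - 8;        // "x_h" in the diff: written to lane 1
    return {elem0, elem1};
}

int main()
{
    // 0x2B packs raw nibbles (2, 11): elem0 = 2-8 = -6, elem1 = 11-8 = 3.
    auto [e0, e1] = unpack_pk_i4(0x2B);
    assert(e0 == -6 && e1 == 3);
    return 0;
}
```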
@@ -219,12 +178,12 @@ struct DequantPack8
     __host__ __device__ constexpr void
     operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
-        result.template AsType<half4_t>()(Number<1>{}) =
-            pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
+        result.template AsType<half4_t>()(Number<1>{}) =
+            i4_to_half4_scale(bit_cast<int>(x) >> 8, z);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else
@@ -232,13 +191,13 @@ struct DequantPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<half2_t>()(Number<1>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<half2_t>()(Number<2>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<half2_t>()(Number<3>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif
@@ -252,7 +211,7 @@ struct PassThroughPack2
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

-    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
         y      = type_convert<half2_t>(t);
@@ -260,7 +219,7 @@ struct PassThroughPack2
     __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
         uint8_t x_l  = (x_u8 & 0x0f) >> 0;
         uint8_t x_h  = (x_u8 & 0xf0) >> 4;
@@ -479,7 +438,7 @@ struct PassThrough
     template <>
     __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
     {
-        y = ck::type_convert<bf8_t>(x);
+        y = type_convert<bf8_t>(x);
     }
 };
@@ -552,21 +511,21 @@ struct Scale
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_);
+        y = type_convert<Y>(type_convert<float>(x) * scale_);
     }

     template <>
     __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::type_convert<half_t>(scale_) * x;
+        y = type_convert<half_t>(scale_) * x;
     };

     template <>
     __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
     {
-        const float x_tmp = ck::type_convert<float>(x);
+        const float x_tmp = type_convert<float>(x);
         const float y_tmp = scale_ * x_tmp;
-        y                 = ck::type_convert<bhalf_t>(y_tmp);
+        y                 = type_convert<bhalf_t>(y_tmp);
     };

     template <>
@@ -584,7 +543,7 @@ struct Scale
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
-        y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x));
+        y = type_convert<int8_t>(scale_ * type_convert<float>(x));
     };

     float scale_;
@@ -600,7 +559,7 @@ struct ScaleAndResetNaNToMinusInfinity
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
+        y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
     };

     float scale_;
@@ -671,12 +630,13 @@ struct UnaryAbs
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                           is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                           is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::abs(x);
+        y = math::abs(x);
     };

     template <>
@@ -694,7 +654,7 @@ struct UnarySqrt
         static_assert(is_same<T, float>::value || is_same<T, double>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sqrt(x);
+        y = math::sqrt(x);
     };
 };

@@ -713,9 +673,9 @@ struct Relu
     template <>
     __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
-        float x_f32 = ck::type_convert<float>(x);
+        float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
-        y           = ck::type_convert<bhalf_t>(y_f32);
+        y           = type_convert<bhalf_t>(y_f32);
     }
 };
@@ -731,7 +691,7 @@ struct FastGelu
     template <typename Y, typename X>
     __device__ void operator()(Y& y, const X& x) const;

+#ifndef CK_CODE_GEN_RTC
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {
@@ -742,6 +702,7 @@ struct FastGelu
         const float emu = exp(u);
         y               = x / (1.f + emu);
     }
+#endif

     // device code, use lower precision "__ocml_exp_f32" and "rcp"
     template <>
@@ -753,7 +714,7 @@ struct FastGelu
         const float u   = x * (c1 * x * x + c2);
         const float emu = __ocml_exp_f32(u);
-        y = x * ck::math::rcp(1.f + emu);
+        y = x * math::rcp(1.f + emu);
     }

     template <>
@@ -851,10 +812,9 @@ struct Gelu
     }

     template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
-                                                                const ck::half_t& x) const
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
+        y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
     }
 };
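The new `#ifndef CK_CODE_GEN_RTC` fence around FastGelu's host overload exists because it calls `exp` from the host math library, which hipRTC cannot provide; the device overload uses `__ocml_exp_f32` and `math::rcp` instead. Both evaluate the same sigmoid form of the tanh approximation. A host-only sketch using the commonly published constants (the actual `c1`, `c2` values are not shown in this hunk, so the ones below are assumptions):

```cpp
#include <cmath>
#include <cstdio>

// FastGelu via the sigmoid form of the tanh approximation:
//   gelu(x) ~= x * sigmoid(2*sqrt(2/pi) * (x + 0.044715*x^3))
//            = x / (1 + exp(u)),  u = x * (c1*x^2 + c2),
// with c2 = -2*sqrt(2/pi) and c1 = c2 * 0.044715 (assumed, not from the diff).
float fast_gelu(float x)
{
    const float c2 = -2.0f * 0.7978845608f;
    const float c1 = c2 * 0.044715f;
    const float u  = x * (c1 * x * x + c2);
    return x / (1.0f + std::exp(u));
}

int main()
{
    printf("fast_gelu(1.0f) = %f\n", fast_gelu(1.0f)); // ~0.8412 for reference
    return 0;
}
```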
@@ -868,7 +828,7 @@ struct Sigmoid
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y = one / (one + ck::math::exp(-x));
+        y = one / (one + math::exp(-x));
     };
 };

@@ -877,11 +837,11 @@ struct Silu
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
+        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, half_t> ||
                           is_same_v<T, int8_t> || is_same_v<T, int32_t>,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y = x * (one / (one + ck::math::exp(-x)));
+        y = x * (one / (one + math::exp(-x)));
     };
 };

@@ -895,7 +855,7 @@ struct TanH
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::tanh(x);
+        y = math::tanh(x);
     };
 };

@@ -905,11 +865,11 @@ struct ACos
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::acos(x);
+        y = math::acos(x);
     };
 };

@@ -919,11 +879,11 @@ struct Neg
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::neg(x);
+        y = math::neg(x);
     };
 };

@@ -933,11 +893,11 @@ struct ATan
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::atan(x);
+        y = math::atan(x);
     };
 };

@@ -947,11 +907,11 @@ struct Sin
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sin(x);
+        y = math::sin(x);
     };
 };

@@ -961,11 +921,11 @@ struct ASinH
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::asinh(x);
+        y = math::asinh(x);
     };
 };

@@ -975,11 +935,11 @@ struct Cos
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::cos(x);
+        y = cos(x);
     };
 };

@@ -989,11 +949,11 @@ struct ACosH
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::acosh(x);
+        y = math::acosh(x);
     };
 };

@@ -1003,11 +963,11 @@ struct Tan
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::tan(x);
+        y = math::tan(x);
     };
 };

@@ -1017,11 +977,11 @@ struct ATanH
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::atanh(x);
+        y = math::atanh(x);
     };
 };

@@ -1031,11 +991,11 @@ struct SinH
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sinh(x);
+        y = math::sinh(x);
     };
 };

@@ -1045,11 +1005,11 @@ struct Ceil
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::ceil(x);
+        y = math::ceil(x);
     };
 };

@@ -1059,11 +1019,11 @@ struct Exp
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::exp(x);
+        y = math::exp(x);
     };
 };

@@ -1073,11 +1033,11 @@ struct CosH
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::cosh(x);
+        y = math::cosh(x);
     };
 };

@@ -1087,11 +1047,11 @@ struct Floor
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::floor(x);
+        y = math::floor(x);
     };
 };

@@ -1101,11 +1061,11 @@ struct Log
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::log(x);
+        y = math::log(x);
     };
 };

@@ -1115,11 +1075,11 @@ struct ASin
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::asin(x);
+        y = math::asin(x);
     };
 };

@@ -1129,11 +1089,11 @@ struct Rcp
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::rcp(x);
+        y = math::rcp(x);
     };
 };
@@ -1153,7 +1113,7 @@ struct Swish
                       "Data type is not supported by this operation!");

         float bx = -beta_ * type_convert<float>(x);
-        y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
+        y        = type_convert<Y>(x / (1.f + math::exp(bx)));
     };

     const float beta_;
@@ -1172,7 +1132,7 @@ struct SoftRelu
                       "Data type is not supported by this operation!");
         T casted_alpha  = type_convert<T>(alpha_);
         constexpr T one = type_convert<T>(1);
-        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+        y               = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
     }
     const float alpha_;
 };
@@ -1193,7 +1153,7 @@ struct Power
         T casted_beta      = type_convert<T>(beta_);
         T casted_gamma     = type_convert<T>(gamma_);
         T shifted_scaled_x = casted_alpha + casted_beta * x;
-        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+        y                  = math::pow(shifted_scaled_x, casted_gamma);
     }
     const float alpha_;
     const float beta_;
@@ -1213,7 +1173,7 @@ struct ClippedRelu
                       "Data type is not supported by this operation!");
         T casted_alpha = type_convert<T>(alpha_);
         T casted_beta  = type_convert<T>(beta_);
-        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+        y              = math::min(casted_beta, math::max(casted_alpha, x));
     }
     const float alpha_;
     const float beta_;
@@ -1248,7 +1208,7 @@ struct Elu
                           is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");
         T casted_alpha = type_convert<T>(alpha_);
-        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+        y              = x > 0 ? x : casted_alpha * math::expm1(x);
     }
     const float alpha_;
 };
@@ -1350,10 +1310,10 @@ struct FastNumericArrayConverter
 };

 template <>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
+struct FastNumericArrayConverter<uint8_t, half_t, 4>
 {
     using InputArray  = vector_type<uint8_t, 4>;
-    using OutputArray = vector_type<ck::half_t, 4>;
+    using OutputArray = vector_type<half_t, 4>;

     __device__ static OutputArray convert(InputArray const& Input)
     {
@@ -1383,13 +1343,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
 };

 template <index_t N>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
+struct FastNumericArrayConverter<uint8_t, half_t, N>
 {
     static constexpr int VEC_WIDTH = 4;
     static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

     using InputArray  = vector_type<uint8_t, N>;
-    using OutputArray = vector_type<ck::half_t, N>;
+    using OutputArray = vector_type<half_t, N>;

     __device__ static OutputArray convert(InputArray const& Input)
     {
@@ -1398,7 +1358,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
         OutputArray Output;

         using Vec_InputArray  = vector_type<uint8_t, 4>;
-        using Vec_OutputArray = vector_type<ck::half_t, 4>;
+        using Vec_OutputArray = vector_type<half_t, 4>;

         Vec_OutputArray* half_4_ptr       = reinterpret_cast<Vec_OutputArray*>(&Output);
         Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
......
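The N-wide `FastNumericArrayConverter` specialization just strides over the 4-wide one, `VEC_WIDTH` elements at a time. A hedged sketch of the wrapping loop the hunk elides (assumes the same namespace as the converters above; the loop form and helper name are inferred, not CK's):

```cpp
// Sketch of what the N-wide convert() body presumably does: walk the input
// in VEC_WIDTH = 4 chunks through the 4-wide specialization. The pointer
// casts mirror the reinterpret_casts shown in the hunk.
template <index_t N>
__device__ vector_type<half_t, N> convert_n(vector_type<uint8_t, N> const& Input)
{
    static_assert(!(N % 4), "N must be multiple of 4.");

    vector_type<half_t, N> Output;

    using Vec_InputArray  = vector_type<uint8_t, 4>;
    using Vec_OutputArray = vector_type<half_t, 4>;

    Vec_OutputArray* half_4_ptr       = reinterpret_cast<Vec_OutputArray*>(&Output);
    Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);

    for(int i = 0; i < N / 4; ++i) // assumed loop; not shown in the diff
        half_4_ptr[i] = FastNumericArrayConverter<uint8_t, half_t, 4>::convert(uint8_4_ptr[i]);

    return Output;
}
```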
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/math.hpp"
 #include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"

+#ifndef CK_CODE_GEN_RTC
 #include <limits>
 #include <stdlib.h>
+#endif

 namespace ck {
@@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
         // Create 3D grid
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-        return std::make_tuple(N0, M0, k_split);
+        return make_tuple(N0, M0, k_split);
     }

     template <typename TopIdx>
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t dp_for_sk_iters = k_iters_per_tile.get();
         uint32_t best_sk_score =
-            std::numeric_limits<int>::max(); // we need to find the smallest sk iters
+            NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
         for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
             tentative_sk_blocks++)
        {
......
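`std::numeric_limits<int>::max()` becomes `NumericLimits<int32_t>::Max()` here because `<limits>` is now only included when `CK_CODE_GEN_RTC` is undefined (see the include guard at the top of this file), so stream-K must rely on CK's own limits shim. A minimal sketch of what such an RTC-safe shim looks like (the real `ck::NumericLimits` has more members; this is a host-buildable demo, and under RTC the integer typedefs would come from CK's headers rather than `<cstdint>`):

```cpp
#include <cstdint>

// RTC-friendly stand-in for std::numeric_limits: plain constexpr functions
// with no standard-library dependency, callable from device code.
template <typename T>
struct NumericLimitsSketch;

template <>
struct NumericLimitsSketch<int32_t>
{
    __host__ __device__ static constexpr int32_t Min() { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Max() { return 2147483647; }
};

// Usage mirroring the stream-K change above:
//   uint32_t best_sk_score = NumericLimitsSketch<int32_t>::Max();
```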
@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
+
         constexpr index_t Gemm1KPack =
             MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
......
@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(
-                MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
-                B1K1),
-            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // therefore we may just as well assign Gemm1KPack = group_size
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;

         auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
             BlockSize,
......
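The second GEMM's K-pack is now pinned to the MFMA `group_size` instead of `max(lcm(group_size, B1K1), k_per_blk)`, because the A1 operand already sits in VGPRs in `group_size`-contiguous chunks and any larger pack would pair mismatched K indices. A toy numeric check of why the old formula could overshoot (the `group_size`, `B1K1`, and `k_per_blk` values below are invented for illustration; real ones come from `MfmaSelector`):

```cpp
#include <algorithm> // std::max
#include <numeric>   // std::lcm

int main()
{
    // Invented example values.
    constexpr int group_size = 4; // contiguous K-elements of A1 held per VGPR group
    constexpr int B1K1       = 8;
    constexpr int k_per_blk  = 4;

    // Old selection: the lcm exceeds group_size whenever B1K1 doesn't divide it...
    constexpr int old_pack = std::max(std::lcm(group_size, B1K1), k_per_blk); // = 8
    // ...which would pair a1[[0:3, 8:11]] against b1[0:7]: mismatched summation indices.
    static_assert(old_pack > group_size);

    // New selection: never exceeds the contiguous run A1 actually has in registers.
    constexpr int new_pack = group_size;
    static_assert(new_pack == 4);
    return 0;
}
```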