Commit 55cdf2b9 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

parents 4b59b5c9 b098b71b
@@ -49,6 +49,7 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,8 +81,10 @@ struct space_filling_curve
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
}
// Do not use this function directly!
// TODO: can refactor into generic lambda in the future
template <index_t AccessIdx1d>
-static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
+static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
{
#if 0
/*
@@ -153,11 +155,11 @@ struct space_filling_curve
return idx_md;
}
-// FIXME: rename this function
+// FIXME: return tuple of number<>, which is compile time only variable
template <index_t AccessIdx1d>
-static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
{
-constexpr auto idx = get_index(number<AccessIdx1d>{});
+constexpr auto idx = _get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
}
...
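// Hypothetical illustration, not part of this commit (Sfc is a placeholder for a concrete
// space_filling_curve instantiation): after the rename, get_index() returns a tuple of
// number<> objects, so each coordinate is usable as a compile-time constant, e.g.
//   constexpr auto idx = Sfc::get_index(number<3>{});
//   constexpr index_t i0 = idx[number<0>{}];
// while _get_index() keeps returning a plain Index and is not meant to be called directly.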
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
namespace impl {
// the types below indicate the payload data type used by the LDS (ds_read) load inline asm
// clang-format off
template<index_t N, typename T> struct smem_load_trait;
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
// clang-format on
} // namespace impl
// NOTE: unlike buffer loads, smem load/store needs no pre_nop to enforce the dependency in software
template <index_t>
struct smem_load;
template <>
struct smem_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
asm volatile("ds_read_b128 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
asm volatile("ds_read_b64 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
asm volatile("ds_read_b32 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u16 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u8 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
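// Hypothetical usage sketch, not part of this commit: read one 16-byte vector from LDS
// through the smem_load<16> wrapper above. lds_byte_offset is assumed to be the per-lane
// LDS byte address of the data (e.g. derived from the lane id); the immediate offset is 0.
CK_TILE_DEVICE fp32x4_t smem_load_example(index_t lds_byte_offset)
{
    fp32x4_t v;
    smem_load<16>{}(v, lds_byte_offset, 0); // emits ds_read_b128 v, lds_byte_offset offset:0
    return v;
}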
// clang-format off
namespace impl{
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_linear_addr_offset,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
-0,
+src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
-0,
+src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant<pre_nop>{});
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; // invalid lanes target num_records, i.e. out of range
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
}
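// Hypothetical usage sketch, not part of this commit: copy one dword per lane from a
// global buffer straight into LDS through the wrapper above. The element count, per-lane
// offset and LDS destination are illustration-only assumptions; with the OOB check on,
// invalid lanes are pointed past the end of the buffer resource so the access is out of range.
CK_TILE_DEVICE void async_copy_dword_example(CK_TILE_LDS_ADDR float* smem_dst,
                                             const float* p_src,
                                             index_t n_elements,
                                             index_t lane_element)
{
    const int32x4_t rsrc = make_wave_buffer_resource(p_src, n_elements * sizeof(float));
    amd_async_buffer_load<float, 1>(smem_dst,
                                    rsrc,
                                    lane_element * static_cast<index_t>(sizeof(float)), // per-lane byte offset
                                    0,                                                  // wave scalar byte offset
                                    0,                                                  // immediate byte offset
                                    lane_element < n_elements,                          // validity flag
                                    bool_constant<true>{});
}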
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
index_t dst_linear_addr_offset,
index_t is_valid_element = 1)
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
-0,
+dst_linear_addr_offset,
is_valid_element);
}
else
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
-0);
+dst_linear_addr_offset);
}
}
@@ -2014,6 +2156,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
@@ -2041,16 +2186,19 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
@@ -2066,6 +2214,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-amd_async_buffer_load_impl<T, N, coherence>(
-    smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version supports a buffer resource as the input arg
@@ -2086,12 +2240,42 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-amd_async_buffer_load_impl<T, N, coherence>(
-    smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version supports a buffer resource as the input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// buffer_store requires:
@@ -2146,6 +2330,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
...
@@ -41,6 +41,19 @@
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
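// Hypothetical illustration, not part of this commit: the macros annotate the address
// space of a pointer in a signature and expand to nothing when not compiling with hipcc.
CK_TILE_DEVICE void clear_lds_word_example(CK_TILE_LDS_ADDR uint32_t* p) { *p = 0u; }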
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
@@ -205,3 +218,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
// workaround: compiler not emitting the reciprocal instruction from __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
@@ -623,7 +623,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
-static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
+static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
@@ -635,7 +635,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
-static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
+static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
@@ -647,7 +647,7 @@ template <typename... Xs,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
-static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
-static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
-static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}
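// Hypothetical illustration, not part of this commit: with the overloads above, two
// tuples of equal size combine elementwise, e.g.
//   constexpr auto a = make_tuple(1, 2, 3);
//   constexpr auto b = make_tuple(4, 5, 6);
//   constexpr auto s = a + b; // {5, 7, 9}
//   constexpr auto d = a - b; // {-3, -3, -3}
//   constexpr auto p = a * b; // {4, 10, 18}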
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
...
@@ -487,55 +487,12 @@ struct log2e<float>
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
}
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
uint32_t xx = bit_cast<uint32_t>(x);
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
{
return __builtin_amdgcn_sad_u16(x, y, acc);
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
return (x > y ? (x - y) : (y - x)) + acc;
}
///////////////////////////////////////////////////////////////
} // namespace ck_tile
// the functions below need these data types to be pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace ck_tile {
#if CK_TILE_WORKAROUND_SWDEV_383542
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST float abs(float x) { return std::abs(x); };
CK_TILE_HOST double abs(double x) { return std::abs(x); };
CK_TILE_HOST int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
}
#endif
CK_TILE_HOST bool isnan(float x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(double x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(int32_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(std::sqrt(static_cast<float>(x)));
};
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
template <typename T>
CK_TILE_HOST T tanh(T x)
{
return type_convert<T>(std::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
CK_TILE_HOST double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
CK_TILE_HOST T acos(T x)
{
return type_convert<T>(std::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acos<float>(float x)
{
return std::acosf(x);
};
template <>
CK_TILE_HOST double acos<double>(double x)
{
return std::acos(x);
};
template <typename T>
CK_TILE_HOST T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_HOST double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <typename T>
CK_TILE_HOST T atan(T x)
{
return type_convert<T>(std::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atan<float>(float x)
{
return std::atanf(x);
};
template <>
CK_TILE_HOST double atan<double>(double x)
{
return std::atan(x);
};
template <typename T>
CK_TILE_HOST T sin(T x)
{
return type_convert<T>(std::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sin<float>(float x)
{
return std::sinf(x);
};
template <>
CK_TILE_HOST double sin<double>(double x)
{
return std::sin(x);
};
template <typename T>
CK_TILE_HOST T asin(T x)
{
return type_convert<T>(std::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asin<float>(float x)
{
return std::asinf(x);
};
template <>
CK_TILE_HOST double asin<double>(double x)
{
return std::asin(x);
};
template <typename T>
CK_TILE_HOST T asinh(T x)
{
return type_convert<T>(std::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asinh<float>(float x)
{
return std::asinhf(x);
};
template <>
CK_TILE_HOST double asinh<double>(double x)
{
return std::asinh(x);
};
template <typename T>
CK_TILE_HOST T cos(T x)
{
return type_convert<T>(std::cosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cos<float>(float x)
{
return std::cosf(x);
};
template <>
CK_TILE_HOST double cos<double>(double x)
{
return std::cos(x);
};
template <typename T>
CK_TILE_HOST T acosh(T x)
{
return type_convert<T>(std::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acosh<float>(float x)
{
return std::acoshf(x);
};
template <>
CK_TILE_HOST double acosh<double>(double x)
{
return std::acosh(x);
};
template <typename T>
CK_TILE_HOST T tan(T x)
{
return type_convert<T>(std::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tan<float>(float x)
{
return std::tanf(x);
};
template <>
CK_TILE_HOST double tan<double>(double x)
{
return std::tan(x);
};
template <typename T>
CK_TILE_HOST T atanh(T x)
{
return type_convert<T>(std::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atanh<float>(float x)
{
return std::atanhf(x);
};
template <>
CK_TILE_HOST double atanh<double>(double x)
{
return std::atanh(x);
};
template <typename T>
CK_TILE_HOST T sinh(T x)
{
return type_convert<T>(std::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sinh<float>(float x)
{
return std::sinhf(x);
};
template <>
CK_TILE_HOST double sinh<double>(double x)
{
return std::sinh(x);
};
template <typename T>
CK_TILE_HOST T ceil(T x)
{
return type_convert<T>(std::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float ceil<float>(float x)
{
return std::ceilf(x);
};
template <>
CK_TILE_HOST double ceil<double>(double x)
{
return std::ceil(x);
};
template <typename T>
CK_TILE_HOST T cosh(T x)
{
return type_convert<T>(std::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cosh<float>(float x)
{
return std::coshf(x);
};
template <>
CK_TILE_HOST double cosh<double>(double x)
{
return std::cosh(x);
};
template <typename T>
CK_TILE_HOST T floor(T x)
{
return type_convert<T>(std::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float floor<float>(float x)
{
return std::floorf(x);
};
template <>
CK_TILE_HOST double floor<double>(double x)
{
return std::floor(x);
};
template <typename T>
CK_TILE_HOST T rcp(T x)
{
return type_convert<T>(1.f / type_convert<float>(x));
};
template <typename T>
CK_TILE_HOST T exp(T x)
{
return type_convert<T>(std::expf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float exp<float>(float x)
{
return std::expf(x);
}
template <>
CK_TILE_HOST double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
CK_TILE_HOST T log(T x)
{
return type_convert<T>(std::logf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float log<float>(float x)
{
return std::logf(x);
}
template <>
CK_TILE_HOST double log<double>(double x)
{
return std::log(x);
}
template <typename T>
CK_TILE_HOST T pow(T x, T gamma)
{
return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
}
template <>
CK_TILE_HOST float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
CK_TILE_HOST double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
CK_TILE_HOST T expm1(T x)
{
return type_convert<T>(std::expm1f(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
CK_TILE_HOST double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
CK_TILE_DEVICE float abs(float x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
};
CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
CK_TILE_DEVICE int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_DEVICE int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
};
#endif
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_DEVICE bool isnan(int32_t x)
{
(void)x;
return false;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_DEVICE bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
template <typename T>
CK_TILE_DEVICE T tanh(T x)
{
return type_convert<T>(::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tanh<float>(float x)
{
return ::tanhf(x);
};
template <>
CK_TILE_DEVICE double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
CK_TILE_DEVICE T acos(T x)
{
return type_convert<T>(::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acos<float>(float x)
{
return ::acosf(x);
};
template <>
CK_TILE_DEVICE double acos<double>(double x)
{
return ::acos(x);
};
template <typename T>
CK_TILE_DEVICE T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_DEVICE double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
{
return __hneg(x);
};
template <typename T>
CK_TILE_DEVICE T atan(T x)
{
return type_convert<T>(::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atan<float>(float x)
{
return ::atanf(x);
};
template <>
CK_TILE_DEVICE double atan<double>(double x)
{
return ::atan(x);
};
template <typename T>
CK_TILE_DEVICE T sin(T x)
{
return type_convert<T>(::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sin<float>(float x)
{
return ::sinf(x);
};
template <>
CK_TILE_DEVICE double sin<double>(double x)
{
return ::sin(x);
};
template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
{
return ::hsin(x);
};
template <typename T>
CK_TILE_DEVICE T asin(T x)
{
return type_convert<T>(::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asin<float>(float x)
{
return ::asinf(x);
};
template <>
CK_TILE_DEVICE double asin<double>(double x)
{
return ::asin(x);
};
template <typename T>
CK_TILE_DEVICE T asinh(T x)
{
return type_convert<T>(::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asinh<float>(float x)
{
return ::asinhf(x);
};
template <>
CK_TILE_DEVICE double asinh<double>(double x)
{
return ::asinh(x);
};
template <typename T>
CK_TILE_DEVICE T acosh(T x)
{
return type_convert<T>(::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acosh<float>(float x)
{
return ::acoshf(x);
};
template <>
CK_TILE_DEVICE double acosh<double>(double x)
{
return ::acosh(x);
};
template <typename T>
CK_TILE_DEVICE T tan(T x)
{
return type_convert<T>(::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tan<float>(float x)
{
return ::tanf(x);
};
template <>
CK_TILE_DEVICE double tan<double>(double x)
{
return ::tan(x);
};
template <typename T>
CK_TILE_DEVICE T atanh(T x)
{
return type_convert<T>(::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atanh<float>(float x)
{
return ::atanhf(x);
};
template <>
CK_TILE_DEVICE double atanh<double>(double x)
{
return ::atanh(x);
};
template <typename T>
CK_TILE_DEVICE T sinh(T x)
{
return type_convert<T>(::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sinh<float>(float x)
{
return ::sinhf(x);
};
template <>
CK_TILE_DEVICE double sinh<double>(double x)
{
return ::sinh(x);
};
template <typename T>
CK_TILE_DEVICE T ceil(T x)
{
return type_convert<T>(::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float ceil<float>(float x)
{
return ::ceilf(x);
};
template <>
CK_TILE_DEVICE double ceil<double>(double x)
{
return ::ceil(x);
};
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{
return ::hceil(x);
};
template <typename T>
CK_TILE_DEVICE T cosh(T x)
{
return type_convert<T>(::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float cosh<float>(float x)
{
return ::coshf(x);
};
template <>
CK_TILE_DEVICE double cosh<double>(double x)
{
return ::cosh(x);
};
template <typename T>
CK_TILE_DEVICE T floor(T x)
{
return type_convert<T>(::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float floor<float>(float x)
{
return ::floorf(x);
};
template <>
CK_TILE_DEVICE double floor<double>(double x)
{
return ::floor(x);
};
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{
return ::hfloor(x);
};
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
return __frcp_rn(x);
#else
// return __ocml_native_recip_f32(x);
return __builtin_amdgcn_rcpf(x);
#endif
};
template <typename T>
CK_TILE_DEVICE T exp(T x)
{
return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{
return hexp(x);
};
template <>
CK_TILE_DEVICE float exp<float>(float x)
{
return __ocml_exp_f32(x);
};
template <>
CK_TILE_DEVICE double exp<double>(double x)
{
return ::exp(x); // call the global double overload, not ck_tile::exp itself
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{
return type_convert<T>(__logf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{
return hlog(x);
};
template <>
CK_TILE_DEVICE float log<float>(float x)
{
return __logf(x);
};
template <>
CK_TILE_DEVICE double log<double>(double x)
{
return ::log(x);
};
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
};
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma)
{
return ::pow(x, gamma);
};
template <typename T>
CK_TILE_DEVICE T expm1(T x)
{
return type_convert<T>(expm1f(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float expm1<float>(float x)
{
return expm1f(x);
};
template <>
CK_TILE_DEVICE double expm1<double>(double x)
{
return ::expm1(x);
};
} // namespace ck_tile
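// Hypothetical illustration, not part of this commit: the wrappers above give host and
// device code the same spelling, with non-float types routed through float, e.g.
//   float e = ck_tile::exp(1.0f); // host: std::expf, device: __ocml_exp_f32
//   float r = ck_tile::rcp(4.0f); // host: 1.f / x, device: hardware reciprocal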
@@ -12,6 +12,7 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp" #include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp" #include "ck_tile/core/tensor/null_tensor.hpp"
...@@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT ...@@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
NumCoord>& tile_window, NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename T,
@@ -46,7 +61,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
-tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
@@ -66,7 +101,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
-lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
-constexpr auto idx_y_in = generate_array(
+constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
...
@@ -24,5 +24,6 @@
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"