Unverified commit cd167e49, authored by Chao Liu and committed by GitHub

Compile for gfx908 and gfx90a (#130)

* adding compilation for multiple targets

* fix build

* clean

* update Jenkinsfile

* update readme

* update Jenkins

* use ck::half_t instead of ushort for bf16

* rename enum classes

* clean

* rename

* clean
parent ecf337ba
...
@@ -3,7 +3,7 @@
 namespace ck {
 
-enum struct DataTypeEnum_t
+enum struct DataTypeEnum
 {
     Half  = 0,
     Float = 1,
...
@@ -6,35 +6,35 @@
 namespace ck {
 
-template <DataTypeEnum_t DataTypeEnum>
+template <DataTypeEnum DataTypeEnum>
 struct get_datatype_from_enum;
 
 template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Int8>
+struct get_datatype_from_enum<DataTypeEnum::Int8>
 {
     using type = int8_t;
 };
 
 template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Int32>
+struct get_datatype_from_enum<DataTypeEnum::Int32>
 {
     using type = int32_t;
 };
 
 template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Half>
+struct get_datatype_from_enum<DataTypeEnum::Half>
 {
     using type = half_t;
 };
 
 template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Float>
+struct get_datatype_from_enum<DataTypeEnum::Float>
 {
     using type = float;
 };
 
 template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Double>
+struct get_datatype_from_enum<DataTypeEnum::Double>
 {
     using type = double;
 };
@@ -45,31 +45,31 @@ struct get_datatype_enum_from_type;
 
 template <>
 struct get_datatype_enum_from_type<int8_t>
 {
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
+    static constexpr DataTypeEnum value = DataTypeEnum::Int8;
 };
 
 template <>
 struct get_datatype_enum_from_type<int32_t>
 {
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
+    static constexpr DataTypeEnum value = DataTypeEnum::Int32;
 };
 
 template <>
 struct get_datatype_enum_from_type<half_t>
 {
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
+    static constexpr DataTypeEnum value = DataTypeEnum::Half;
 };
 
 template <>
 struct get_datatype_enum_from_type<float>
 {
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
+    static constexpr DataTypeEnum value = DataTypeEnum::Float;
 };
 
 template <>
 struct get_datatype_enum_from_type<double>
 {
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
+    static constexpr DataTypeEnum value = DataTypeEnum::Double;
 };
 
 } // namespace ck
...
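Note: the two traits above are inverse mappings between ck::DataTypeEnum and concrete types. A minimal illustration (the static_asserts below are illustrative, not part of this commit):

```cpp
#include <type_traits>

// Illustrative only: round-trip between the enum and the concrete type,
// assuming the ck headers defining these traits are included.
static_assert(
    std::is_same<ck::get_datatype_from_enum<ck::DataTypeEnum::Float>::type, float>::value,
    "enum -> type");
static_assert(ck::get_datatype_enum_from_type<float>::value == ck::DataTypeEnum::Float,
              "type -> enum");
```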
-#ifndef CK_BUFFER_HPP
-#define CK_BUFFER_HPP
+#pragma once
 
 #include "amd_buffer_addressing.hpp"
 #include "c_style_pointer_cast.hpp"
 #include "config.hpp"
@@ -8,7 +6,7 @@
 namespace ck {
 
-template <AddressSpaceEnum_t BufferAddressSpace,
+template <AddressSpaceEnum BufferAddressSpace,
           typename T,
           typename ElementSpaceSize,
           bool InvalidElementUseNumericalZeroValue>
@@ -34,7 +32,7 @@ struct DynamicBuffer
     {
     }
 
-    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
+    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
     {
         return BufferAddressSpace;
     }
@@ -55,7 +53,7 @@ struct DynamicBuffer
         constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
 
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
-                      "wrong! X need to be multiple T");
+                      "wrong! X should contain multiple T");
 
 #if CK_USE_AMD_BUFFER_LOAD
         bool constexpr use_amd_buffer_addressing = true;
@@ -63,7 +61,7 @@ struct DynamicBuffer
         bool constexpr use_amd_buffer_addressing = false;
 #endif
 
-        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
+        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
         {
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
@@ -81,50 +79,48 @@ struct DynamicBuffer
         }
         else
         {
-            if constexpr(InvalidElementUseNumericalZeroValue)
+            if(is_valid_element)
             {
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                 X tmp;
 
                 __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
 
-                return is_valid_element ? tmp : X{0};
+                return tmp;
 #else
-                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
+                return *c_style_pointer_cast<const X*>(&p_data_[i]);
 #endif
             }
             else
             {
-#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
-                X tmp;
-
-                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
-
-                return is_valid_element ? tmp : X{invalid_element_value_};
-#else
-                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
-                                        : X{invalid_element_value_};
-#endif
+                if constexpr(InvalidElementUseNumericalZeroValue)
+                {
+                    return X{0};
+                }
+                else
+                {
+                    return X{invalid_element_value_};
+                }
             }
         }
     }
 
-    template <InMemoryDataOperationEnum_t Op,
+    template <InMemoryDataOperationEnum Op,
               typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                          typename scalar_type<remove_cvref_t<T>>::type>::value,
                                  bool>::type = false>
     __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
     {
-        if constexpr(Op == InMemoryDataOperationEnum_t::Set)
+        if constexpr(Op == InMemoryDataOperationEnum::Set)
         {
             this->template Set<X>(i, is_valid_element, x);
         }
-        else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
+        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
         {
             this->template AtomicAdd<X>(i, is_valid_element, x);
         }
-        else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
+        else if constexpr(Op == InMemoryDataOperationEnum::Add)
         {
             auto tmp = this->template Get<X>(i, is_valid_element);
             this->template Set<X>(i, is_valid_element, x + tmp);
@@ -145,143 +141,120 @@ struct DynamicBuffer
         constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
 
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
-                      "wrong! X need to be multiple T");
+                      "wrong! X should contain multiple T");
 
-        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
-        {
 #if CK_USE_AMD_BUFFER_STORE
-            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
-
-            amd_buffer_store<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i, is_valid_element, element_space_size_);
+        bool constexpr use_amd_buffer_addressing = true;
 #else
-            if(is_valid_element)
-            {
-#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
-                X tmp = x;
-
-                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
-#else
-                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
-#endif
-            }
+        bool constexpr use_amd_buffer_addressing = false;
 #endif
-        }
-        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
+
+#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+        bool constexpr workaround_int8_ds_write_issue = true;
+#else
+        bool constexpr workaround_int8_ds_write_issue = false;
+#endif
+
+        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
+        {
+            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+            amd_buffer_store<remove_cvref_t<T>, t_per_x>(
+                x, p_data_, i, is_valid_element, element_space_size_);
+        }
+        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
+                          is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
+                          workaround_int8_ds_write_issue)
         {
             if(is_valid_element)
             {
-#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
-#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
-                X tmp = x;
-
-                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
-#else
-                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
-#endif
-#else
-                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
-                // inefficient
+                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                 // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                 // ds_write_b128
                 // TODO: remove this after compiler fix
-                if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
-                {
                 static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                                is_same<remove_cvref_t<X>, int8_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8_t>::value &&
                                is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8_t>::value &&
                                is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8_t>::value &&
                                is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8_t>::value &&
                                is_same<remove_cvref_t<X>, int8x16_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                               (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                is_same<remove_cvref_t<X>, int8x16_t>::value),
                               "wrong! not implemented for this combination, please add "
                               "implementation");
 
                 if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                              is_same<remove_cvref_t<X>, int8_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int8_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x2_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int16_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32x2_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32x4_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32x2_t*>(&x);
                 }
                 else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value)
                 {
                     // HACK: cast pointer of x is bad
                     // TODO: remove this after compiler fix
                     *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                         *c_style_pointer_cast<const int32x4_t*>(&x);
                 }
-                }
-                else
-                {
-#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
-                    X tmp = x;
-
-                    __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
-#else
-                    *c_style_pointer_cast<X*>(&p_data_[i]) = x;
-#endif
-                }
-#endif
             }
         }
         else
@@ -305,27 +278,49 @@ struct DynamicBuffer
                                  bool>::type = false>
     __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
     {
+        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
+
         // X contains multiple T
         constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
 
         constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
 
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
-                      "wrong! X need to be multiple T");
+                      "wrong! X should contain multiple T");
 
-        static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
 
-#if CK_USE_AMD_BUFFER_ATOMIC_ADD
-        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
+            is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
+        bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
+#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+#else
+        bool constexpr use_amd_buffer_addressing = false;
+#endif
 
-        amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-            x, p_data_, i, is_valid_element, element_space_size_);
-#else
-        if(is_valid_element)
+        if constexpr(use_amd_buffer_addressing)
         {
-            atomicAdd(&p_data_[i], x);
+            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
+                x, p_data_, i, is_valid_element, element_space_size_);
         }
-#endif
+        else
+        {
+            if(is_valid_element)
+            {
+                // FIXME: atomicAdd is defined by HIP, need to avoid implicit type casting when
+                // calling it
+                atomicAdd(c_style_pointer_cast<X*>(&p_data_[i]), x);
+            }
+        }
     }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
 
@@ -333,14 +328,14 @@ struct DynamicBuffer
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };
 
-template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
+template <AddressSpaceEnum BufferAddressSpace, typename T, typename ElementSpaceSize>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
 {
     return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
 }
 
 template <
-    AddressSpaceEnum_t BufferAddressSpace,
+    AddressSpaceEnum BufferAddressSpace,
     typename T,
     typename ElementSpaceSize,
     typename X,
@@ -353,4 +348,3 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
 }
 
 } // namespace ck
-#endif
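Note: the refactor keeps DynamicBuffer's public surface intact while renaming the enums. A minimal usage sketch (the kernel, names, and sizes below are invented for illustration, not taken from the commit):

```cpp
#include "buffer.hpp"

using namespace ck;

// Hypothetical device-side copy; shows Get/Update with the renamed enums.
__device__ void copy_one_element(const float* p_src, float* p_dst, index_t n)
{
    auto src_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_src, n);
    auto dst_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_dst, n);

    const index_t i  = static_cast<index_t>(threadIdx.x);
    const bool valid = i < n;

    // Out-of-range reads come back as X{0} because make_dynamic_buffer
    // sets InvalidElementUseNumericalZeroValue = true.
    float v = src_buf.template Get<float>(i, valid);

    // Update dispatches to Set / AtomicAdd / Get+Set based on the op enum.
    dst_buf.template Update<InMemoryDataOperationEnum::Set, float>(i, valid, v);
}
```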
-#ifndef CK_UTILITY_HPP
-#define CK_UTILITY_HPP
+#pragma once
 
 #include "config.hpp"
 
 namespace ck {
@@ -16,5 +14,3 @@ __device__ index_t get_block_1d_id() { return blockIdx.x; }
 __device__ index_t get_grid_size() { return gridDim.x; }
 
 } // namespace ck
-#endif
...
@@ -3,7 +3,7 @@
 #include "common_header.hpp"
 
-#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX
+#if CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX
 #include "array_multi_index.hpp"
 #else
 #include "statically_indexed_array_multi_index.hpp"
...
@@ -28,7 +28,7 @@
 namespace ck {
 
-enum class ReduceTensorOp_t
+enum struct ReduceTensorOp
 {
     ADD = 0,
     MUL = 1,
@@ -41,19 +41,19 @@ enum class ReduceTensorOp_t
     // MUL_NO_ZEROS = 8,
 };
 
-enum class NanPropagation_t
+enum struct NanPropagation
 {
     NOT_PROPAGATE_NAN = 0,
     PROPAGATE_NAN     = 1,
 };
 
-enum class ReduceTensorIndices_t
+enum struct ReduceTensorIndices
 {
     NO_INDICES        = 0,
     FLATTENED_INDICES = 1,
 };
 
-enum class IndicesType_t
+enum struct IndicesType
 {
     INDICES_32BIT = 0,
     INDICES_64BIT = 1,
...
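Note: these remain scoped enums, so call sites spell values as ck::ReduceTensorOp::ADD and so on. A hypothetical helper (not part of this commit) showing the spelling:

```cpp
// Hypothetical logging helper; illustrates the renamed scoped enum.
inline const char* reduce_op_name(ck::ReduceTensorOp op)
{
    switch(op)
    {
    case ck::ReduceTensorOp::ADD: return "ADD";
    case ck::ReduceTensorOp::MUL: return "MUL";
    default: return "other";
    }
}
```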
@@ -6,7 +6,7 @@
 namespace ck {
 
 // static buffer for scalar
-template <AddressSpaceEnum_t AddressSpace,
+template <AddressSpaceEnum AddressSpace,
           typename T,
           index_t N,
           bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
@@ -17,10 +17,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
     __host__ __device__ constexpr StaticBuffer() : base{} {}
 
-    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
-    {
-        return AddressSpace;
-    }
+    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
@@ -42,7 +39,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
 };
 
 // static buffer for vector
-template <AddressSpaceEnum_t AddressSpace,
+template <AddressSpaceEnum AddressSpace,
           typename S,
           index_t NumOfVector,
           index_t ScalarPerVector,
@@ -59,10 +56,7 @@ struct StaticBufferTupleOfVector
     __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
 
-    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
-    {
-        return AddressSpace;
-    }
+    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
@@ -158,7 +152,7 @@ struct StaticBufferTupleOfVector
     }
 };
 
-template <AddressSpaceEnum_t AddressSpace, typename T, index_t N>
+template <AddressSpaceEnum AddressSpace, typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
     return StaticBuffer<AddressSpace, T, N, true>{};
...
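Note: a minimal sketch of the factory after the rename (the header name and example function are assumptions for illustration):

```cpp
#include "static_buffer.hpp" // assumed name of the header above

__device__ void static_buffer_example()
{
    using namespace ck;

    // Hypothetical: an 8-element float buffer held in registers (Vgpr address space).
    auto acc = make_static_buffer<AddressSpaceEnum::Vgpr, float>(Number<8>{});

    static_assert(decltype(acc)::GetAddressSpace() == AddressSpaceEnum::Vgpr, "");
    static_assert(decltype(acc)::IsStaticBuffer(), "");
    (void)acc;
}
```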
@@ -7,7 +7,7 @@ namespace ck {
 __device__ void block_sync_lds()
 {
-#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
     asm volatile("\
     s_waitcnt lgkmcnt(0) \n \
     s_barrier \
...
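Note: the renamed macro still selects the lighter barrier path (waiting only on LDS traffic, not vmem). The usual pattern around block_sync_lds() is a producer/consumer fence over an LDS tile; a sketch with invented tile logic:

```cpp
// Sketch only: each thread writes a slice of a shared tile, then all read.
__device__ void lds_tile_example(float* p_lds_tile /* LDS-resident */)
{
    // ... each thread writes its slice of the tile into p_lds_tile ...

    ck::block_sync_lds(); // make all LDS writes visible before any thread reads

    // ... threads read each other's data from p_lds_tile ...
}
```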
@@ -75,14 +75,14 @@ calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDes
 }
 
 template <typename T>
-inline auto activ(T v, const ck::ActivTypeEnum_t activ_type)
+inline auto activ(T v, const ck::ActivTypeEnum activ_type)
 {
     const T alpha = 0.3;
 
     switch(activ_type)
     {
-    case ck::ActivTypeEnum_t::None: return v;
-    case ck::ActivTypeEnum_t::LeakyRelu: return (v >= 0 ? v : alpha * v);
-    case ck::ActivTypeEnum_t::Sigmoid: return (1 / (1 + exp(-v)));
+    case ck::ActivTypeEnum::None: return v;
+    case ck::ActivTypeEnum::LeakyRelu: return (v >= 0 ? v : alpha * v);
+    case ck::ActivTypeEnum::Sigmoid: return (1 / (1 + exp(-v)));
     default: throw std::runtime_error("unsupported activ type"); break;
     }
 }
...
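Note: a quick host-side check of the helper above (values picked arbitrarily; alpha is the fixed 0.3 inside activ):

```cpp
#include <cassert>

void activ_example()
{
    // LeakyRelu: negative inputs are scaled by alpha = 0.3.
    float y_neg = activ(-2.0f, ck::ActivTypeEnum::LeakyRelu); // approx. -0.6f
    float y_pos = activ(2.0f, ck::ActivTypeEnum::LeakyRelu);  // exactly  2.0f

    assert(y_neg < 0.0f && y_pos == 2.0f);
}
```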
 #pragma once
 
 #include "host_tensor.hpp"
+#include "common_header.hpp"
 
 template <typename TensorDesc>
 void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
...
@@ -39,8 +39,8 @@ namespace ck {
 namespace host_reduce {
 
-using ck::NanPropagation_t;
-using ck::ReduceTensorOp_t;
+using ck::NanPropagation;
+using ck::ReduceTensorOp;
 
 template <typename T>
 static inline bool float_equal_one(T);
@@ -66,44 +66,44 @@ static inline bool float_equal_zero(half_float::half x)
     return x == static_cast<half_float::half>(0.0f);
 };
 
-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 __host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
 {
     using std::abs;
 
-    if constexpr(ReduceOpId == ReduceTensorOp_t::NORM1)
+    if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
     {
         return ([&](AccDataType& a_) { a_ = abs(a_); });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
+    else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
     {
         return ([&](AccDataType& a_) { a_ = a_ * a_; });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::AMAX)
+    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
     {
         return ([&](AccDataType& a_) { a_ = abs(a_); });
     }
     else
     {
-        // ReduceTensorOp_t::AVG:
-        // ReduceTensorOp_t::ADD:
-        // ReduceTensorOp_t::MUL:
-        // ReduceTensorOp_t::MIN:
-        // ReduceTensorOp_t::MAX:
+        // ReduceTensorOp::AVG:
+        // ReduceTensorOp::ADD:
+        // ReduceTensorOp::MUL:
+        // ReduceTensorOp::MIN:
+        // ReduceTensorOp::MAX:
         return ([&](AccDataType&) {});
     };
 };
 
-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 __host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
 {
     using std::sqrt;
 
-    if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
+    if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
     {
         return ([&](AccDataType& a_) { a_ = sqrt(a_); });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::AVG)
+    else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
     {
         return ([&, divider](AccDataType& a_) {
             a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
@@ -111,36 +111,36 @@ __host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t di
     }
     else
     {
-        // ReduceTensorOp_t::ADD:
-        // ReduceTensorOp_t::NORM1:
-        // ReduceTensorOp_t::MUL:
-        // ReduceTensorOp_t::MIN:
-        // ReduceTensorOp_t::MAX:
-        // ReduceTensorOp_t::AMAX:
+        // ReduceTensorOp::ADD:
+        // ReduceTensorOp::NORM1:
+        // ReduceTensorOp::MUL:
+        // ReduceTensorOp::MIN:
+        // ReduceTensorOp::MAX:
+        // ReduceTensorOp::AMAX:
         return ([&](AccDataType&) {});
     }
 };
 
-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 __host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
 {
-    if constexpr(ReduceOpId == ReduceTensorOp_t::ADD || ReduceOpId == ReduceTensorOp_t::AVG ||
-                 ReduceOpId == ReduceTensorOp_t::NORM1 || ReduceOpId == ReduceTensorOp_t::NORM2)
+    if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
+                 ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
     {
         return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
     {
         return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
         return ([&](AccDataType& a_, AccDataType b_) {
             if(a_ > b_)
                 a_ = b_;
         });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
     {
         return ([&](AccDataType& a_, AccDataType b_) {
             if(a_ < b_)
@@ -149,10 +149,10 @@ __host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn
     }
 };
 
-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 __host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
 {
-    if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
+    if constexpr(ReduceOpId == ReduceTensorOp::MIN)
     {
         return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
             if(a_ > b_)
@@ -164,7 +164,7 @@ __host__ static inline std::function<void(AccDataType&, AccDataType, bool& chang
             changed = false;
         });
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
     {
         return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
             if(a_ < b_)
@@ -178,40 +178,40 @@ __host__ static inline std::function<void(AccDataType&, AccDataType, bool& chang
     }
     else
     {
-        // ReduceTensorOp_t::ADD:
-        // ReduceTensorOp_t::MUL:
-        // ReduceTensorOp_t::AVG:
-        // ReduceTensorOp_t::NORM1:
-        // ReduceTensorOp_t::NORM2:
+        // ReduceTensorOp::ADD:
+        // ReduceTensorOp::MUL:
+        // ReduceTensorOp::AVG:
+        // ReduceTensorOp::NORM1:
+        // ReduceTensorOp::NORM2:
         return (std::function<void(AccDataType&, AccDataType, bool&)>{});
     };
 };
 
-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 __host__ static inline AccDataType ReduceOpZeroVal()
 {
-    if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
+    if constexpr(ReduceOpId == ReduceTensorOp::MUL)
     {
         return (static_cast<AccDataType>(1.0f));
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
     {
         return (std::numeric_limits<AccDataType>::max());
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX)
+    else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
     {
         return (std::numeric_limits<AccDataType>::lowest());
     }
-    else if constexpr(ReduceOpId == ReduceTensorOp_t::AMAX)
+    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
     {
         return (static_cast<AccDataType>(0.0f));
     }
     else
     {
-        // ReduceTensorOp_t::ADD
-        // ReduceTensorOp_t::AVG
-        // ReduceTensorOp_t::NORM1
-        // ReduceTensorOp_t::NORM2
+        // ReduceTensorOp::ADD
+        // ReduceTensorOp::AVG
+        // ReduceTensorOp::NORM1
+        // ReduceTensorOp::NORM2
         return (static_cast<AccDataType>(0.0f));
     };
 };
...
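Note: the renamed helpers compose into a host reference reduction roughly as follows (a sketch assuming AccDataType = float and the ck::host_reduce helpers above in scope; the function itself is invented):

```cpp
#include <vector>

template <ck::ReduceTensorOp ReduceOpId>
float host_reduce_example(const std::vector<float>& data)
{
    using namespace ck::host_reduce;

    auto pre_op = PreUnaryOpFn<float, ReduceOpId>(static_cast<int>(data.size()));
    auto reduce = ReduceOpFn<float, ReduceOpId>();
    auto pos_op = PosUnaryOpFn<float, ReduceOpId>(static_cast<int>(data.size()));

    float acc = ReduceOpZeroVal<float, ReduceOpId>();

    for(float v : data)
    {
        pre_op(v);      // e.g. abs() for NORM1/AMAX, square for NORM2
        reduce(acc, v); // accumulate
    }

    pos_op(acc); // e.g. sqrt for NORM2, divide-by-count for AVG
    return acc;
}

// float n2 = host_reduce_example<ck::ReduceTensorOp::NORM2>(my_data);
```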
@@ -104,7 +104,7 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          ck::ReduceTensorOp_t ReduceOpId,
+          ck::ReduceTensorOp ReduceOpId,
           int Rank,
           int NumReduceDim,
           bool PropagateNan,
...
@@ -6,7 +6,7 @@
 template <typename TInWei,
           typename TAcc,
           typename TOut,
-          ck::ActivTypeEnum_t activ_type,
+          ck::ActivTypeEnum activ_type,
           typename InLengths,
           typename WeiLengths,
           typename AddLengths,
...
@@ -231,7 +231,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
         TInWei,
         TAcc,
         TOut,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(in_gemmm_gemmn_grid_desc),
...
@@ -338,7 +338,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
         TInWei,
         TAcc,
         TOut,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(in_gemmm_gemmn_grid_desc),
...
@@ -307,7 +307,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
         TInWei,
         TAcc,
         TOut,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(in_gemmm_gemmn_grid_desc),
...
@@ -171,7 +171,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
         TIn,
         TAcc,
         TWei,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(wei_gemmm_gemmn_grid_desc),
...
@@ -168,7 +168,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
         TIn,
         TAcc,
         TWei,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(wei_gemmm_gemmn_grid_desc),
...
@@ -200,7 +200,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
         TIn,
         TAcc,
         TWei,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(wei_gemmm_gemmn_grid_desc),
...
@@ -199,7 +199,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
         TIn,
         TAcc,
         TWei,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(wei_gemmm_gemmn_grid_desc),
...
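Note on the *_atomic_* drivers: they pass InMemoryDataOperationEnum::AtomicAdd, which funnels through DynamicBuffer::Update into AtomicAdd, so the destination tensor must start from zero for the accumulation to be correct. A hypothetical host-side sketch (p_wei_grid and wei_bytes stand in for the driver's own buffer and size):

```cpp
// Hypothetical: zero the output once, then launch the atomic kernel,
// which accumulates partial results with AtomicAdd.
hipMemset(p_wei_grid, 0, wei_bytes);
```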