Commit 50540084 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into aosewski/gemm_tile_loop

parents 9c4d265f 2474dddb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool SrcResetCoordinateAfterRun,
bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
// loop over space-filling curve
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
SrcData v;
// apply element-wise operation
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
// apply type convert
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetCoordinateResetStep()
{
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;
};
} // namespace ck
...@@ -629,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -629,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{ {
static_assert( static_assert(
(is_same<T, double>::value && (N == 1 || N == 2)) || (is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
...@@ -682,6 +682,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -682,6 +682,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset, dst_wave_addr_offset,
static_cast<index_t>(coherence)); static_cast<index_t>(coherence));
} }
else if constexpr(N == 8)
{
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
} }
else if constexpr(is_same<T, half_t>::value) else if constexpr(is_same<T, half_t>::value)
{ {
......
...@@ -13,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c); ...@@ -13,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c);
template <> template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c) __device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{ {
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) #if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
asm volatile("\n \ asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \ v_mac_f32 %0, %1, %2 \n \
" "
: "=v"(c) : "=v"(c)
: "v"(a), "v"(b), "0"(c)); : "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) #elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
asm volatile("\n \ asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \ v_fmac_f32 %0, %1, %2 \n \
" "
...@@ -76,14 +76,18 @@ template <> ...@@ -76,14 +76,18 @@ template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
#if defined(CK_USE_AMD_V_DOT2_F32_F16) #if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM #if CK_USE_AMD_V_DOT_INLINE_ASM
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
// ) s_nop with parameter 2 is equal to 3 x s_nop
asm volatile("\n \ asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \ v_dot2_f32_f16 %0, %1, %2, %0\n \
s_nop 2 \n \
" "
: "=v"(c) : "=v"(c)
: "v"(a), "v"(b), "0"(c)); : "v"(a), "v"(b), "0"(c));
#else #else
c = __builtin_amdgcn_sdot2(a, b, c, false); c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif #endif
#else #else
const vector_type<half_t, 2> a_vector{a}; const vector_type<half_t, 2> a_vector{a};
...@@ -163,9 +167,13 @@ __device__ void ...@@ -163,9 +167,13 @@ __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c) inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{ {
#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM #if CK_USE_AMD_V_DOT_INLINE_ASM
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
// ) s_nop with parameter 2 is equal to 3 x s_nop
asm volatile("\n \ asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \
s_nop 2 \n \
" "
: "=v"(c) : "=v"(c)
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c)); : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
......
...@@ -157,4 +157,76 @@ struct MagicDivision ...@@ -157,4 +157,76 @@ struct MagicDivision
} }
}; };
struct MDiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
__host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
{
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}
__host__ __device__ void update(uint32_t divisor_)
{
divisor = divisor_;
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
__host__ __device__ uint32_t get() const { return divisor; }
};
struct MDiv2
{
// 1 dword -> 2 dword storage, divisor need compute from runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
__host__ __device__ MDiv2(uint32_t divisor_)
{
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
} // namespace ck } // namespace ck
...@@ -240,5 +240,21 @@ struct less ...@@ -240,5 +240,21 @@ struct less
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
}; };
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
return Y;
}
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
{
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
return Number<Y>{};
}
} // namespace math } // namespace math
} // namespace ck } // namespace ck
...@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_ ...@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32 // convert bfp16 to int8 via fp32
template <> template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
......
#pragma once
#include <hip/hip_runtime.h>
#include <stdint.h>
namespace ck {
struct workgroup_barrier
{
__device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
__device__ uint32_t ld(uint32_t offset)
{
#if 0
float d = llvm_amdgcn_raw_buffer_load_fp32(
amdgcn_make_buffer_resource(base_ptr),
0,
offset,
AMDGCN_BUFFER_GLC);
union cvt {
float f32;
uint32_t u32;
};
cvt x;
x.f32 = d;
return x.u32;
#endif
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
}
__device__ void wait_eq(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0)
{
while(ld(offset) != value) {}
}
__syncthreads();
}
__device__ void wait_lt(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0)
{
while(ld(offset) < value) {}
}
__syncthreads();
}
__device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
{
if(threadIdx.x == 0)
{
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
}
__syncthreads();
}
// enter critical zoon, assume buffer is zero when launch kernel
__device__ void aquire(uint32_t offset) { wait_set(offset, 0, 1); }
// exit critical zoon, assume buffer is zero when launch kernel
__device__ void release(uint32_t offset) { wait_set(offset, 1, 0); }
__device__ void inc(uint32_t offset)
{
__syncthreads();
if(threadIdx.x == 0)
{
atomicAdd(base_ptr + offset, 1);
}
}
uint32_t* base_ptr;
};
} // namespace ck
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef DL_KERNELS
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -335,3 +336,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -335,3 +336,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( ...@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#ifdef __int8__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1, std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC, NWC,
...@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( ...@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
// conv2d backward data // conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( ...@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC, NHWC,
...@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
// conv2d dl // conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( ...@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC, NHWC,
...@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( ...@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
// conv3d backward data // conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( ...@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC, NDHWC,
...@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( ...@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
} }
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
} }
#endif
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> && else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>) is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
...@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
} }
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
} }
#endif
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>) is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
...@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
} }
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
} }
#endif
} }
return op_ptrs; return op_ptrs;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmStreamK<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmStreamK<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if 0
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) ...@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
*/ */
struct DeviceMem struct DeviceMem
{ {
DeviceMem() = delete; DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size); DeviceMem(std::size_t mem_size);
void Realloc(std::size_t mem_size);
void* GetDeviceBuffer() const; void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const; std::size_t GetBufferSize() const;
void ToDevice(const void* p) const; void ToDevice(const void* p) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment