Commit d43cd4ad authored by Mirza Halilcevic

Introduce gemm_softmax_gemm to codegen.

parent 3528a523
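The change follows one pattern throughout the headers: host-only facilities (std::string, iostreams, BaseArgument/BaseInvoker, IsSupportedArgument, GetTypeString) are fenced off behind #ifndef __HIPCC_RTC__, std:: utilities on the device path are swapped for ck:: equivalents (Array, byte, forward, NumericLimits, enable_if), and the device ops gain a constexpr Descriptor plus a __device__ Run() so hipRTC-compiled kernels generated by codegen can drive them directly. A minimal sketch of the guarding pattern (illustrative names only, not code from this commit):

#ifndef __HIPCC_RTC__
#include <string> // host-only header, never seen by hipRTC
#endif

struct ExampleDeviceOp
{
#ifndef __HIPCC_RTC__
    // host-side interface: argument construction, support checks, type strings
    std::string GetTypeString() const { return "ExampleDeviceOp"; }
#endif

    // RTC-visible interface: built from compile-time descriptors and callable
    // from a generated __global__ kernel
    template <class Desc>
    __device__ static void Run(const Desc& desc);
};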
@@ -3,15 +3,17 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <string>
#include <sstream>
#include "ck/stream_config.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
#ifndef __HIPCC_RTC__
struct BaseArgument
{
BaseArgument() = default;
@@ -36,6 +38,7 @@ struct BaseInvoker
virtual ~BaseInvoker() {}
};
#endif
struct BaseOperator
{
@@ -43,6 +46,7 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default;
#ifndef __HIPCC_RTC__
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
@@ -66,7 +70,7 @@ struct BaseOperator
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
#endif
virtual ~BaseOperator() {}
};
...
@@ -2,9 +2,10 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif
#include "device_base.hpp"
@@ -28,6 +29,7 @@ template <typename ALayout,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device
...
@@ -2,9 +2,11 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <array>
#endif
#include "ck/utility/array.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
@@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
@@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
// GEMM:
@@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
@@ -94,6 +99,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device
...
@@ -28,7 +28,7 @@ enum struct GemmSpecialization
NKOPadding,
MNKOPadding,
};
#ifndef __HIPCC_RTC__
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{
switch(s)
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
default: return "Unrecognized specialization!";
}
}
#endif
} // namespace device
} // namespace tensor_operation
...
@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -15,8 +19,6 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
@@ -40,7 +42,7 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
@@ -430,6 +432,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
matrix_padder.PadN,
MaskOutUpperTriangle>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
@@ -604,6 +607,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsValidCompilationParameter()
{
@@ -611,6 +615,97 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true;
}
static constexpr bool
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row>)
{
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col>)
{
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Row>)
{
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Col>)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B1
if constexpr(is_same_v<B1Layout, Row>)
{
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<B1Layout, Col>)
{
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of C
if constexpr(is_same_v<CLayout, Row>)
{
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else if constexpr(is_same_v<CLayout, Col>)
{
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
return true;
}
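To make the constraint above concrete: with row-major A and an instance whose ABlockTransferSrcScalarPerVector is 8 (an assumed value; the real one comes from the instance's tuning template arguments), the raw K extent must be divisible by 8 or IsSupported() rejects the shape. A tiny standalone illustration of the same rule, not part of the commit:

// Illustrative only: the divisibility rule behind the vector load/store checks.
constexpr bool vector_load_ok(int extent, int scalar_per_vector)
{
    return extent % scalar_per_vector == 0;
}
static_assert(vector_load_ok(4096, 8), "K = 4096 is fine for 8-wide loads");
static_assert(!vector_load_ok(100, 8), "K = 100 is rejected");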
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
@@ -765,6 +860,269 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str();
}
#endif
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template <class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
template <class Desc>
__device__ static void Run(const Desc& desc,
const float scale,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
Desc::GridwiseGemm::template Run<false>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
};
} // namespace device
...
@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -14,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
@@ -35,7 +37,7 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
const Array<index_t, NumDTensor>& NRaws,
const Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
@@ -309,6 +311,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
@@ -498,6 +501,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
};
#endif
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
{
// check vector load/store
@@ -578,6 +583,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
@@ -676,11 +682,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
{LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle"
@@ -709,6 +717,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str();
}
#endif
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
@@ -847,7 +856,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.IsValid());
#endif
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
...
@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
MaskOutUpperTriangle
};
#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
default: return "Unrecognized specialization!";
}
}
#endif
struct MaskDisabledPredicate
{
@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
...
@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
} // namespace convolution
#ifndef __HIPCC_RTC__
template <
typename Layout,
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
os << Layout::name;
return os;
}
#endif
} // namespace tensor_layout
} // namespace ck
@@ -340,8 +340,8 @@ struct Bilinear
};
template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t>(
int8_t& y, const int32_t& x0, const int8_t& x1) const
{
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
...
@@ -466,7 +466,7 @@ struct FastGelu
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
#ifndef __HIPCC_RTC__
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
@@ -477,7 +477,7 @@ struct FastGelu
const float emu = exp(u);
y = x / (1.f + emu);
}
#endif
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
...
@@ -7,8 +7,10 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#ifndef __HIPCC_RTC__
#include <limits>
#include <stdlib.h>
#endif
namespace ck {
@@ -979,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return ck::make_tuple(N0, M0, k_split);
}
template <typename TopIdx>
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
ck::NumericLimits<int>::Max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
...
@@ -475,9 +475,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
const Array<index_t, NumDTensor>& NRaws,
const Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
@@ -941,7 +941,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K,
const index_t StrideA,
const index_t StrideB,
const Array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{
...
@@ -3,8 +3,10 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <ostream>
#endif
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
@@ -53,12 +55,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
else
{
#ifndef __HIPCC_RTC__
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
}
}
} // namespace ck
#ifndef __HIPCC_RTC__
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
switch(p)
@@ -71,3 +76,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
}
return os;
}
#endif
@@ -1005,6 +1005,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
#ifndef __HIPCC_RTC__
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
@@ -1042,5 +1043,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#endif
} // namespace ck
@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#ifndef __HIPCC_RTC__
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#endif
namespace ck {
namespace detail {
@@ -37,7 +39,7 @@ struct get_carrier<3>
{
using value_type = uint32_t;
Array<ck::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n()
@@ -61,22 +63,22 @@ struct get_carrier<3>
// method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept
{
copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
}
public:
__device__ carrier& operator=(value_type value) noexcept
{
copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());
return *this;
}
__device__ operator value_type() const noexcept
{
ck::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.Size(), result);
return *reinterpret_cast<const value_type*>(result);
}
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
constexpr unsigned object_size = sizeof(int64_t);
constexpr unsigned second_part_offset = object_size / 2;
auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
alignas(int64_t) ck::byte to_obj[object_size];
using Sgpr = uint32_t;
@@ -124,15 +126,15 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
template <
typename Object,
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
using Size = unsigned;
constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
alignas(Object) ck::byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
...
@@ -38,6 +38,8 @@ struct Array
}
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
__host__ __device__ constexpr TData* begin() { return &mData[0]; }
__host__ __device__ constexpr TData* end() { return &mData[NSize]; }
};
// empty Array
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}
// make empty array
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
...
@@ -5,8 +5,25 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
#endif // __HIPCC_RTC__
namespace ck {
#ifdef __HIPCC_RTC__
using byte = unsigned char;
#else
using std::byte;
#endif
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
@@ -1060,6 +1077,146 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
#ifdef __HIPCC_RTC__
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
static constexpr unsigned int binary_min = 0x00800000;
static constexpr unsigned int binary_max = 0x7F7FFFFF;
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
static constexpr unsigned int binary_qnan = 0xFFC00001;
static constexpr unsigned int binary_inf = 0x7F800000;
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
struct NumericLimits<half_t>
{
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
};
#else
template <typename T>
struct NumericLimits
{
@@ -1151,6 +1308,7 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
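The float limits in the __HIPCC_RTC__ branch above are encoded as raw bit patterns because <limits> is unavailable under hipRTC. A small host-side sanity check (illustrative only, not part of the commit) that those patterns reproduce the standard values:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

static float from_bits(std::uint32_t u)
{
    float f;
    std::memcpy(&f, &u, sizeof(f)); // host-side equivalent of bit_cast
    return f;
}

int main()
{
    assert(from_bits(0x00800000u) == std::numeric_limits<float>::min());
    assert(from_bits(0x7F7FFFFFu) == std::numeric_limits<float>::max());
    assert(from_bits(0xFF7FFFFFu) == std::numeric_limits<float>::lowest());
    assert(from_bits(0x7F800000u) == std::numeric_limits<float>::infinity());
    return 0;
}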
template <typename T>
struct NumericUtils
...
@@ -4,11 +4,26 @@
#pragma once
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#else
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#endif
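Since <type_traits> cannot be included under hipRTC, the commit hand-rolls enable_if in namespace ck using the usual two-partial-specialization definition. A short host-compilable check (illustrative, not from the commit) that this definition drives SFINAE exactly like std::enable_if_t:

// Illustrative only: mirrors the ck::enable_if definition above.
template <bool B, class T = void>
struct my_enable_if
{
};
template <class T>
struct my_enable_if<true, T>
{
    using type = T;
};
template <bool B, class T = void>
using my_enable_if_t = typename my_enable_if<B, T>::type;

// Accepts only types no larger than 4 bytes; the condition avoids <type_traits>.
template <class T, my_enable_if_t<(sizeof(T) <= 4), bool> = true>
constexpr T small_only(T x) { return x; }

static_assert(small_only(42) == 42, "int is accepted");
// small_only(3.14) would not compile: double fails the constraint.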
} // namespace ck
@@ -183,3 +183,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)