Unverified Commit 2cab8d39 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

CK Tile FA Training kernels (#1286)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 76827d82
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
// Tags the available FMHA backward-pass block-pipeline variants. Codegen
// pattern-matches on these tags when selecting/emitting kernel instances,
// so the underlying values are fixed explicitly and must stay stable.
// (Names appear to encode which operands are staged where — e.g. shared
// memory vs. registers — but confirm against the pipeline implementations.)
enum class BlockFmhaBwdPipelineEnum
{
    KSKTSVR      = 0,
    QSKSVROGradS = 1,
    KSVR         = 2,
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Compile-time "problem descriptor" for the FMHA backward (gradient) block
// pipelines. It gathers, under stable member names, everything a pipeline
// implementation needs: the element type of every tensor involved in the
// backward pass, the block/tile shape, the attention-mask type, and the
// boolean/enum switches carried by the Traits type. Pipelines and codegen
// read these members by name; do not rename or remove them.
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename GemmDataType_,
typename LSEDataType_,
typename AccDataType_,
typename DDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename ODataType_,
typename OGradDataType_,
typename QGradDataType_,
typename KGradDataType_,
typename VGradDataType_,
typename BiasGradDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
// Forward-pass operand/output element types (cv/ref-stripped).
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
// Element type used for the intermediate GEMMs of the backward pass.
using GemmDataType = remove_cvref_t<GemmDataType_>;
// Log-sum-exp values saved by the forward pass and re-read here.
using LSEDataType = remove_cvref_t<LSEDataType_>;
// Accumulator type for GEMM results (typically wider than the inputs).
using AccDataType = remove_cvref_t<AccDataType_>;
// Type of the D = rowsum(dO * O) intermediate.
using DDataType = remove_cvref_t<DDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
// Element type for the stored dropout random values (debug/validation output).
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
// Incoming gradient (dO) and the gradients this pipeline produces.
using OGradDataType = remove_cvref_t<OGradDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using KGradDataType = remove_cvref_t<KGradDataType_>;
using VGradDataType = remove_cvref_t<VGradDataType_>;
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
// Tile-shape descriptor; also supplies NumWarps for the block size below.
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
// Threads per workgroup, derived from the shape's warp count.
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
// Group (variable-seqlen) vs. batch mode.
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
// Desired resident blocks per compute unit (occupancy hint).
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
// Problem descriptor for the small pre-pass of the FMHA backward kernel that
// computes D = rowsum(dO * O). Unlike the main backward problem above, the
// block size is passed explicitly (not derived from a tile shape) and is
// checked to be a positive multiple of the warp size.
template <typename ODataType_,
typename OGradDataType_,
typename DDataType_,
index_t kBlockSize_,
index_t kVHeaddim_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaBwdOGradDotOPipelineProblem
{
// Forward output O, incoming gradient dO, and the resulting D element types.
using ODataType = remove_cvref_t<ODataType_>;
using OGradDataType = remove_cvref_t<OGradDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// Threads per workgroup (validated above).
static constexpr index_t kBlockSize = kBlockSize_;
// Head dimension of V (== head dimension of O/dO rows reduced here).
static constexpr index_t kVHeaddim = kVHeaddim_;
// Group (variable-seqlen) vs. batch mode.
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// Desired resident blocks per compute unit (occupancy hint).
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
...@@ -13,6 +13,7 @@ template <typename QDataType_, ...@@ -13,6 +13,7 @@ template <typename QDataType_,
typename SaccDataType_, typename SaccDataType_,
typename SMPLComputeDataType_, typename SMPLComputeDataType_,
typename BiasDataType_, typename BiasDataType_,
typename RandValOutputDataType_,
typename LSEDataType_, typename LSEDataType_,
typename PDataType_, typename PDataType_,
typename OaccDataType_, typename OaccDataType_,
...@@ -23,19 +24,20 @@ template <typename QDataType_, ...@@ -23,19 +24,20 @@ template <typename QDataType_,
typename Traits_> typename Traits_>
struct BlockFmhaPipelineProblem struct BlockFmhaPipelineProblem
{ {
using QDataType = remove_cvref_t<QDataType_>; using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>; using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>; using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>; using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>; using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>; using BiasDataType = remove_cvref_t<BiasDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>; using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
using PDataType = remove_cvref_t<PDataType_>; using LSEDataType = remove_cvref_t<LSEDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>; using PDataType = remove_cvref_t<PDataType_>;
using ODataType = remove_cvref_t<ODataType_>; using OaccDataType = remove_cvref_t<OaccDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using ODataType = remove_cvref_t<ODataType_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using Traits = remove_cvref_t<Traits_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
...@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem ...@@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -14,19 +15,20 @@ namespace ck_tile { ...@@ -14,19 +15,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct BlockFmhaPipelineQRKSVS struct BlockFmhaPipelineQRKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS ...@@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(), v_dram_block_window_tmp.get_window_lengths(),
...@@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS ...@@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS
}); });
}); });
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
}
block_sync_lds(); block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
...@@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
...@@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS ...@@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr,
BlockDropout& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
...@@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp, lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr,
dropout);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,19 +13,20 @@ namespace ck_tile { ...@@ -13,19 +13,20 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct BlockFmhaPipelineQSKSVS struct BlockFmhaPipelineQSKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
......
This diff is collapsed.
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 ...@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>, BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>; BlockGemmARegBSmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment