Commit 061009a3 authored by aska-0096

Compile pass

parent d1894bdb
@@ -315,7 +315,7 @@ struct Blockwise_fpAintB_GemmWMMA
            fast_numeric_converter;
        // basic intrinsic to determine loopover direction
-       if constexpr(MRepeat < NRepeat)
+       if constexpr( 0 )
        {
            static_for<0, KPerBlock / WmmaK, 1>{}(
                [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace ck {
/**
 * @brief Blockwise data transfer with dequantization
 *
 * RunRead loads the low-precision source data and the scale data.
 * RunWrite performs the dequantization and writes the result to the destination.
 * The scale is assumed to be identical along the K-dimension.
 *
 * This version does the following to avoid the scratch-memory issue:
 * 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
 * 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
 * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
 *
 */
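// Illustrative call order only (a minimal sketch; the descriptor/buffer names are
// placeholders, not declared in this header):
//
//   blockwise_copy.RunRead(src_grid_desc, src_grid_buf);          // load low-precision data
//   blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf); // load scales (constant along K) once
//   blockwise_copy.RunWrite(dst_block_desc, dst_block_buf);       // dequantize while writing out
//   blockwise_copy.MoveSrcSliceWindow(src_grid_desc, step);       // advance to the next K tile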
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename ScaleElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename BlockScaleSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename ScaleData,
typename DstData,
typename SrcDesc,
typename ScaleDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t ScaleScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t ScaleScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto scale_thread_slice_lengths = BlockScaleSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const ScaleDesc& scale_desc,
const Index& scale_block_slice_origin,
const ScaleElementwiseOperation& scale_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
scale_desc,
make_zero_multi_index<nDim>(),
scale_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<ScaleDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
is_same<BlockScaleSliceLengths, decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetScaleSliceOrigin(scale_desc,
scale_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
// Under the assumption above, a single scale scratch is sufficient, so no thread_scratch_id is needed
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc,
const ScaleBuffer& scale_buf)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
// Calling this API directly is not recommended
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
// Under the assumption above, the scale buffer does not need a MoveSliceWindow method
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
decltype(scale_thread_slice_lengths),
SrcElementwiseOperation,
ScaleElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
ScaleData,
DstData,
SrcDesc,
ScaleDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
ScaleScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
ScaleScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
@@ -66,7 +66,7 @@ template <typename ALayout,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
-         ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1>
+         ck::PipelineVersion PipelineVer = ck::PipelineVersion::weight_only>
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
                                                                    BLayout,
                                                                    CLayout,
@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
    static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;

    // If true, LDS is used unconditionally
-   // LDS bypass feature not checked.
+   // LDS bypass feature not implemented for dequantization pipeline.
    static constexpr auto AEnableLds_manu = true;
    static constexpr auto BEnableLds_manu = true;
@@ -677,7 +677,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"},
            {PipelineVersion::v2, "v2"},
-           {PipelineVersion::dequant_v1, "dequant_v1"}};
+           {PipelineVersion::dequant_v1, "dequant_v1"},
+           {PipelineVersion::weight_only, "weight_only"}};

        // clang-format off
        str << "DeviceFpAintBGemm_Wmma_CShuffle"
......
@@ -13,6 +13,7 @@ enum struct PipelineVersion
    v1,
    v2,
    dequant_v1,
+   weight_only,
};

template <PipelineVersion PipelineVer,
@@ -41,6 +42,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
    {
        return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
    }
+   else if constexpr(PipelineVer == PipelineVersion::weight_only)
+   {
+       return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
+   }
    else
    {
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
......
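For reference, a hedged sketch of how a gridwise kernel could pick the new pipeline at compile time. The selector's full template-parameter list is not shown in this hunk, so the argument order below (version, NumPrefetch, AEnableLds, BEnableLds) is an assumption for illustration only.

// Hedged sketch: template-argument order is assumed, not taken from this diff.
using GridwiseGemmPipe = remove_cvref_t<decltype(
    GridwiseGemmPipeline_Selector<PipelineVersion::weight_only,
                                  /*NumPrefetch=*/1,
                                  /*AEnableLds=*/true,
                                  /*BEnableLds=*/true>())>;
// With <1, true, true> this should resolve to GridwiseGemmPipeline_v1_WeightOnly<1, true, true>,
// the only specialization defined by this commit.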
@@ -769,6 +769,109 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, false>
    }
};

template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;
template <>
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleGridBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleGridBuffer& scale_grid_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Global Prefetch Stage 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// The scale is read only once, before the main loop
b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Dequantization is fused into the blockwise copy's RunWrite
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
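The dequantization fused into the blockwise copy above is applied element-wise while B is written from registers to LDS. As a standalone illustration of the math only (the function and parameter names below are hypothetical, not CK identifiers), each low-precision weight is converted to the compute type and multiplied by a scale that is constant along K:

#include <cstdint>

// Hypothetical standalone model of the per-element dequantization; `dequantize_element`
// and its parameters are illustrative names, not CK identifiers.
inline float dequantize_element(std::int8_t quantized_weight, float channel_scale)
{
    // Convert the low-precision weight to the compute type, then apply the scale.
    // The same scale is reused for every K index of the output channel.
    return static_cast<float>(quantized_weight) * channel_scale;
}

// Example: a stored weight of -3 with a channel scale of 0.05f dequantizes to -0.15f.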
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......