Commit b7419eec authored by Jing Zhang

separate float a/b

parent 57d0ea67
@@ -6,12 +6,12 @@
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"

 using ADataType = ck::half_t;
-using BDataType = ck::half_t;
+using BDataType = int8_t;
 using CDataType = ck::half_t;
 using AccDataType = float;

 using ALayout = Row;
-using BLayout = Col;
+using BLayout = Row;
 using CLayout = Row;

 using AElementOp = PassThrough;
@@ -23,12 +23,11 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
-// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| CThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
+// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
 // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
 // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | Order| | |
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;
-< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 16, 4, 1, 1, 1, S<1, 1, 1, 4>, S<16, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 1>;
+< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1, 1, 1, 4>, S<2, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, 2, 8, S<0, 1, 2, 3, 4, 5>, 5, 1>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
...
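The example config above now pairs an fp16 A with an int8 B (both row-major), while C stays fp16 and accumulation stays float. Below is a minimal standalone sketch of what separate A/B element types plus a common compute type mean for a GEMM; the names (gemm_ref, ComputeT, AccT) are illustrative, float stands in for ck::half_t, and this is not the ComposableKernel API.

#include <cstdint>
#include <cstdio>
#include <vector>

// A and B may use different storage types; both are converted to ComputeT
// before the multiply, and products are accumulated in AccT.
template <typename AT, typename BT, typename ComputeT, typename AccT, typename CT>
void gemm_ref(const AT* a, const BT* b, CT* c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccT acc = 0;
            for(int k = 0; k < K; ++k)
                acc += static_cast<AccT>(static_cast<ComputeT>(a[m * K + k]) *
                                         static_cast<ComputeT>(b[k * N + n]));
            c[m * N + n] = static_cast<CT>(acc);
        }
}

int main()
{
    const int M = 2, N = 2, K = 3;
    std::vector<float> a(M * K, 1.0f);       // stand-in for fp16 A
    std::vector<std::int8_t> b(K * N, 2);    // int8 B, as in the updated example
    std::vector<float> c(M * N, 0.0f);
    gemm_ref<float, std::int8_t, float, float, float>(a.data(), b.data(), c.data(), M, N, K);
    std::printf("c[0] = %f\n", c[0]);        // expect 6.0
    return 0;
}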
@@ -80,6 +80,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     using GridwiseGemm =
         GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                      ADataType,
+                                     BDataType,
                                      AccDataType,
                                      CDataType,
                                      InMemoryDataOperationEnum::Set,
@@ -184,10 +185,17 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
             float ave_time = 0;

+            using ComputeType = ADataType;
+
             if(has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, true>;
+                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
+                                                        ADataType,
+                                                        BDataType,
+                                                        ComputeType,
+                                                        CDataType,
+                                                        true,
+                                                        true>;
                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -206,8 +214,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
             }
             else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, false>;
+                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
+                                                        ADataType,
+                                                        BDataType,
+                                                        ComputeType,
+                                                        CDataType,
+                                                        true,
+                                                        false>;
                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -226,8 +239,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
             }
             else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
             {
-                const auto kernel =
-                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, true>;
+                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
+                                                        ADataType,
+                                                        BDataType,
+                                                        ComputeType,
+                                                        CDataType,
+                                                        false,
+                                                        true>;
                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -246,8 +264,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
             }
             else
             {
-                const auto kernel =
-                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, false>;
+                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
+                                                        ADataType,
+                                                        BDataType,
+                                                        ComputeType,
+                                                        CDataType,
+                                                        false,
+                                                        false>;
                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
...
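The four branches above differ only in the two compile-time loop flags; each one instantiates kernel_gemm_dl_v1r3 with the same type list (A, B, compute, C), and ComputeType is simply aliased to ADataType. A compressed sketch of that dispatch pattern follows; fake_kernel and dispatch are illustrative stand-ins, not CK code.

#include <cstdint>
#include <cstdio>

// Stand-in for the GPU kernel: the template parameters mirror the new
// kernel_gemm_dl_v1r3 ordering (A type, B type, compute type, C type, flags).
template <typename FloatA, typename FloatB, typename ComputeType, typename FloatC,
          bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void fake_kernel(const FloatA*, const FloatB*, FloatC*)
{
    std::printf("main_loop=%d double_tail=%d\n", HasMainKBlockLoop, HasDoubleTailKBlockLoop);
}

// Host-side dispatch: pick one of four instantiations from two runtime bools.
template <typename FloatA, typename FloatB, typename FloatC>
void dispatch(const FloatA* a, const FloatB* b, FloatC* c,
              bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    using ComputeType = FloatA; // same choice the device code makes
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        fake_kernel<FloatA, FloatB, ComputeType, FloatC, true, true>(a, b, c);
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        fake_kernel<FloatA, FloatB, ComputeType, FloatC, true, false>(a, b, c);
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        fake_kernel<FloatA, FloatB, ComputeType, FloatC, false, true>(a, b, c);
    else
        fake_kernel<FloatA, FloatB, ComputeType, FloatC, false, false>(a, b, c);
}

int main()
{
    float a[1] = {0};
    std::int8_t b[1] = {0};
    float c[1] = {0};
    dispatch(a, b, c, true, false); // prints main_loop=1 double_tail=0
    return 0;
}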
@@ -19,7 +19,9 @@
 namespace ck {

 template <typename GridwiseGemm,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
+          typename ComputeType,
           typename FloatC,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -27,8 +29,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
-                            const FloatAB* __restrict__ p_b_grid,
+        kernel_gemm_dl_v1r3(const FloatA* __restrict__ p_a_grid,
+                            const FloatB* __restrict__ p_b_grid,
                             FloatC* __restrict__ p_c_grid,
                             const index_t M,
                             const index_t N,
@@ -38,9 +40,9 @@ __global__ void
                             const index_t StrideC)
 {
     constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ComputeType);

-    __shared__ FloatAB p_shared_block[shared_block_size];
+    __shared__ ComputeType p_shared_block[shared_block_size];

     const auto a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
     const auto b_grid_desc_k0_n_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -68,7 +70,8 @@ __global__ void
 }

 template <index_t BlockSize,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -121,7 +124,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_aligned_space_size =
             math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

-        return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size) * sizeof(ComputeType);
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
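With the A/B types split, the LDS staging buffer is declared and sized in ComputeType instead of the old FloatAB: GetSharedMemoryNumberOfByte() reserves the double-buffered A tile in ComputeType, and the kernel divides that byte count by sizeof(ComputeType) to get shared_block_size. A host-side sketch of the same arithmetic, with illustrative names and an assumed tile size (std::uint16_t stands in for a 2-byte fp16 compute type):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Round x up to a multiple of align (mirrors math::integer_least_multiple).
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

// Double-buffered A tile held in LDS, sized in the compute type.
template <typename ComputeType>
constexpr std::size_t shared_memory_bytes(std::size_t a_block_elements, std::size_t lds_align)
{
    return 2 * integer_least_multiple(a_block_elements, lds_align) * sizeof(ComputeType);
}

int main()
{
    constexpr std::size_t a_tile_elems = 16 * 64 * 4; // assumed A tile, purely illustrative
    constexpr std::size_t bytes = shared_memory_bytes<std::uint16_t>(a_tile_elems, 8);
    constexpr std::size_t shared_block_size = bytes / sizeof(std::uint16_t); // as in the kernel
    std::printf("LDS bytes = %zu, elements = %zu\n", bytes, shared_block_size);
    return 0;
}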
@@ -368,12 +371,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));

     using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

+    using ComputeType = FloatA;
+
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void
-    Run(const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
+    Run(const FloatA* __restrict__ p_a_grid,
+        const FloatB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        FloatAB* __restrict__ p_shared_block,
+        ComputeType* __restrict__ p_shared_block,
         const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
         const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
         const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
@@ -423,6 +428,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                    a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                "wrong!");

+        ignore = a_global_buf;
+
         // A matrix blockwise copy
         auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
@@ -431,8 +438,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
             ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
             ABlockTransferThreadClusterArrangeOrder,
-            FloatAB,
-            FloatAB,
+            FloatA,
+            ComputeType,
             remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
             decltype(a_block_desc_k0_m0_m1_k1),
             ABlockTransferSrcAccessOrder,
@@ -451,8 +458,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             make_tuple(Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));

         auto b_threadwise_copy =
-            ThreadwiseTensorSliceTransfer_v2<FloatAB,
-                                             FloatAB,
+            ThreadwiseTensorSliceTransfer_v2<FloatB,
+                                             ComputeType,
                                              remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
                                              decltype(b_thread_desc_k0_n0_n1_k1),
                                              Sequence<K0PerBlock, 1, NPerThread, K1.value>,
@@ -470,8 +477,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         const auto blockwise_gemm =
             BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
-                                                 FloatAB,
-                                                 FloatAB,
+                                                 ComputeType,
+                                                 ComputeType,
                                                  FloatAcc,
                                                  decltype(a_k0_m_k1_block_desc),
                                                  decltype(b_k0_n_k1_thread_desc),
@@ -489,12 +496,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

-        FloatAB* p_a_block_double = p_shared_block;
+        ComputeType* p_a_block_double = p_shared_block;

-        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeType>(
             b_k0_n_k1_thread_desc.GetElementSpaceSize());
-        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeType>(
             b_k0_n_k1_thread_desc.GetElementSpaceSize());

         // register allocation for output
@@ -516,8 +523,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+            // a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            // a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);

             b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
                                   b_global_buf,
@@ -544,7 +551,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                                                       b_thread_slice_copy_step);

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                // a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);

                 b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
                                       b_global_buf,
@@ -558,7 +565,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+                // a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
@@ -568,7 +575,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                                                     b_thread_slice_copy_step);

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                // a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);

                 b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
                                       b_global_buf,
@@ -582,7 +589,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+                // a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);

                 k_block_data_begin += 2 * K0PerBlock;
             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
@@ -598,7 +605,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             block_sync_lds();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            // a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);

             b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
                                   b_global_buf,
@@ -610,7 +617,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+            // a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);

             block_sync_lds();
...
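Note that the B-side copy now names different source and destination element types (FloatB in global memory, ComputeType in registers), presumably so the int8 B values are widened to the compute type as they are copied, and the blockwise GEMM then operates purely on ComputeType. The A-side blockwise RunRead/RunWrite calls are commented out in this commit, so the A path appears to be a work in progress. Below is a standalone sketch of a copy that converts element types on the fly; threadwise_copy_sketch is an illustrative stand-in, not CK's ThreadwiseTensorSliceTransfer_v2.

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for a threadwise copy whose source and destination
// element types differ: each element is converted as it is moved, so an int8
// B tile can land in registers already widened to the compute type.
template <typename SrcT, typename DstT, int NElem>
struct threadwise_copy_sketch
{
    void run(const SrcT* src, DstT* dst) const
    {
        for(int i = 0; i < NElem; ++i)
            dst[i] = static_cast<DstT>(src[i]); // convert while copying
    }
};

int main()
{
    std::int8_t b_global[4] = {1, -2, 3, -4}; // FloatB = int8_t, as in the example
    float b_thread[4];                        // ComputeType stand-in (fp16 on device)
    threadwise_copy_sketch<std::int8_t, float, 4>{}.run(b_global, b_thread);
    std::printf("%.1f %.1f %.1f %.1f\n", b_thread[0], b_thread[1], b_thread[2], b_thread[3]);
    return 0;
}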
@@ -12,7 +12,8 @@ cmake
     -save-temps=$PWD" \
     -D CMAKE_BUILD_TYPE=Release \
     -D BUILD_DEV=ON \
-    -D GPU_TARGETS="gfx1100" \
+    -D GPU_TARGETS="gfx90a" \
+    -D DL_KERNEL=ON \
     -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
     -D USE_BITINT_EXTENSION_INT4=OFF \
     ${MY_PROJECT_SOURCE}
...